diff --git a/build.sh b/build.sh index 97bb7c0b..9b0ec759 100755 --- a/build.sh +++ b/build.sh @@ -2,6 +2,7 @@ set -e BUILD_DEEPEP_MODULE="ON" +BUILD_DEEPEP_OPS="ON" BUILD_KERNELS_MODULE="ON" BUILD_MEMORY_SAVER_MODULE="ON" @@ -20,6 +21,11 @@ while getopts ":a:hd" opt; do case "$OPTARG" in deepep ) BUILD_DEEPEP_MODULE="ON" + BUILD_DEEPEP_OPS="ON" + ;; + deepep2 ) + BUILD_DEEPEP_MODULE="ON" + BUILD_DEEPEP_OPS="OFF" ;; kernels ) BUILD_KERNELS_MODULE="ON" @@ -120,7 +126,11 @@ function build_deepep_kernels() if [[ "$ONLY_BUILD_DEEPEP_ADAPTER_MODULE" == "ON" ]]; then return 0; fi if [[ "$BUILD_DEEPEP_MODULE" != "ON" ]]; then return 0; fi - KERNEL_DIR="csrc/deepep/ops" + if [[ "$BUILD_DEEPEP_OPS" == "ON" ]]; then + KERNEL_DIR="csrc/deepep/ops" + else + KERNEL_DIR="csrc/deepep/ops2" + fi CUSTOM_OPP_DIR="${CURRENT_DIR}/python/deep_ep/deep_ep" cd "$KERNEL_DIR" || exit @@ -137,6 +147,7 @@ function build_deepep_kernels() echo "find run package: $custom_opp_file" chmod +x "$custom_opp_file" fi + rm -rf "$CUSTOM_OPP_DIR"/vendors ./build_out/custom_opp_*.run --install-path=$CUSTOM_OPP_DIR cd - } diff --git a/csrc/deepep/CMakeLists.txt b/csrc/deepep/CMakeLists.txt index 65628411..b5147573 100644 --- a/csrc/deepep/CMakeLists.txt +++ b/csrc/deepep/CMakeLists.txt @@ -1,6 +1,14 @@ # this is the cmakelist file for deepep build # deepep will be built as separated wheel package +if(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64") + set(DEEPEP_ARCH "x86_64") +elseif(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "aarch64") + set(DEEPEP_ARCH "aarch64") +else() + message(FATAL_ERROR "Unsupported host processor: ${CMAKE_SYSTEM_PROCESSOR}") +endif() + set(PROJECT_BUILD_PATH ${PROJECT_BINARY_DIR}) set(TARGET_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}) set(ASCEND_HOME_PATH ${ASCEND_HOME_PATH}) @@ -27,6 +35,7 @@ target_include_directories( deep_ep_cpp PRIVATE ${TORCH_NPU_DIR}/include/third_party/acl/inc/acl ${TORCH_NPU_DIR}/include/third_party/acl/inc ${ASCEND_HOME_PATH}/include + ${ASCEND_HOME_PATH}/${DEEPEP_ARCH}-linux/include/experiment/platform ) target_link_directories(deep_ep_cpp PRIVATE ${TORCH_DIR}/lib @@ -38,6 +47,7 @@ target_link_libraries(deep_ep_cpp PRIVATE ascendcl hccl torch_npu + opapi ) message(STATUS "TARGET_INSTALL_DIR = ${TARGET_INSTALL_DIR}") diff --git a/csrc/deepep/deep_ep.cpp b/csrc/deepep/deep_ep.cpp index 3718efbe..d843aaa9 100644 --- a/csrc/deepep/deep_ep.cpp +++ b/csrc/deepep/deep_ep.cpp @@ -1,7 +1,12 @@ #include #include +#include #include +#include +#include +#include + #include "hccl/hccl.h" #include "exception.hpp" #include "deep_ep.hpp" @@ -12,6 +17,11 @@ constexpr int PADDING_SIZE = 3; constexpr size_t HCOMM_NAME_LEN = 128; constexpr uint32_t NO_SCALES = 0; constexpr uint32_t DYNAMIC_SCALES = 2; +constexpr uint32_t MAX_BS = 4096; +constexpr int A3_MAX_HCCS_PEERS = 384; +constexpr int A2_MAX_HCCS_PEERS = 8; +constexpr int A2_MAX_BATCH_SIZE = 4096; +constexpr int A2_EXPERT_DATA_SIZE = 1 + 2 * A2_MAX_BATCH_SIZE; // 8193 Buffer::Buffer(int64_t rank, int64_t num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, std::string moe_all_to_all_group_name) @@ -37,6 +47,18 @@ Buffer::Buffer(int64_t rank, int64_t num_ranks, int64_t num_nvl_bytes, int64_t n } this->shared_expert_rank_num = get_value_from_env("MOE_SHARED_EXPERT_RANK_NUM", 0); + + num_rdma_ranks = 1; + num_nvl_ranks = num_ranks; + rdma_rank = rank; + nvl_rank = rank; + soc_version = op::GetCurrentPlatformInfo().GetSocVersion(); + if (soc_version == op::SocVersion::ASCEND910B) { + num_rdma_ranks = std::max(static_cast(1), num_ranks / 
A2_MAX_HCCS_PEERS); + num_nvl_ranks = std::min(num_ranks, static_cast(A2_MAX_HCCS_PEERS)); + rdma_rank = rank / A2_MAX_HCCS_PEERS; + nvl_rank = rank % A2_MAX_HCCS_PEERS; + } } Buffer::~Buffer() noexcept(false) {} @@ -46,6 +68,11 @@ bool Buffer::is_available() const return available; } +int Buffer::get_num_rdma_ranks() const +{ + return num_rdma_ranks; +} + std::tuple, torch::Tensor, torch::Tensor, std::optional> Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std::optional &previous_event, bool async, bool allocate_on_comm_stream) @@ -73,15 +100,24 @@ Buffer::get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std: const int num_tokens = new_topk_idx.size(0); const int num_topk = new_topk_idx.size(1); + const int local_ranksize = A2_MAX_HCCS_PEERS; + auto server_num = num_ranks / local_ranksize; auto device = new_topk_idx.device(); auto num_tokens_per_expert = at::zeros({num_experts}, at::dtype(at::kInt).device(device)); auto num_tokens_per_rank = at::zeros({num_ranks}, at::dtype(at::kInt).device(device)); auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, at::dtype(at::kInt).device(device)); + const int total_size = num_experts * A2_EXPERT_DATA_SIZE + server_num + MAX_BS * (1 + 2 * server_num + num_experts); + auto total_data = at::zeros({total_size}, at::dtype(at::kInt).device(device)); + total_data + .index({at::indexing::Slice(num_experts + server_num + MAX_BS * (server_num + 1), + num_experts + server_num + MAX_BS * (server_num * 2 + num_experts + 1))}) + .fill_(-1); - EXEC_NPU_CMD(aclnnDispatchLayout, new_topk_idx, num_tokens, num_ranks, num_experts, num_topk, num_tokens_per_rank, - num_tokens_per_expert, is_token_in_rank); + EXEC_NPU_CMD(aclnnDispatchLayout, new_topk_idx, num_tokens, num_ranks, num_experts, num_topk, local_ranksize, + num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank, total_data); + this->send_data = total_data; std::optional num_tokens_per_rdma_rank = std::nullopt; std::optional output_event = std::nullopt; auto is_token_in_rank_bool = is_token_in_rank.to(at::kBool); @@ -401,6 +437,279 @@ Buffer::intranode_combine(const torch::Tensor &x, const torch::Tensor &topk_idx, return {combined_x, recv_topk_weights, event}; } +std::tuple, std::optional, std::optional, + std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, + std::optional> +Buffer::internode_dispatch( + const torch::Tensor &x, const std::optional &x_scales, const std::optional &topk_idx, + const std::optional &topk_weights, const std::optional &num_tokens_per_rank, + const std::optional &num_tokens_per_rdma_rank, const torch::Tensor &is_token_in_rank, + const std::optional &num_tokens_per_expert, const Config &config, + std::optional &previous_event, bool async, bool allocate_on_comm_stream, bool use_quant) +{ + // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. 
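+    // Example (illustrative only): with config.num_sms == 24 this yields num_channels == 12;
+    // channel c is assumed to use block 2 * c for sending and block 2 * c + 1 for receiving.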
+ EP_HOST_ASSERT(config.num_sms % 2 == 0); + int num_channels = config.num_sms / 2; + + at::Tensor new_x = x; + // for padding + if (topk_idx->size(0) < PADDING_SIZE) { + this->is_padding = true; + this->padding_cnt = PADDING_SIZE - topk_idx->size(0); + std::vector x_blocks; + if (topk_idx->size(0) != 0) { + x_blocks.emplace_back(x); + } else { + this->ori_x = x.clone(); + } + for (int i = 0; i < this->padding_cnt; i++) { + at::Tensor tmp_x = torch::ones({1, x.size(1)}, x.options()) * (i + 1) * 2; + x_blocks.emplace_back(tmp_x); + } + new_x = torch::cat(x_blocks, 0); + } + + // Type checks + EP_HOST_ASSERT(num_tokens_per_expert.has_value()); + EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == at::kInt); + // Shape and contiguous checks + EP_HOST_ASSERT(new_x.dim() == 2 and new_x.is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); + + auto num_tokens = static_cast(new_x.size(0)), hidden = static_cast(new_x.size(1)); + auto num_experts = static_cast(num_tokens_per_expert->size(0)); + auto num_local_experts = static_cast(num_experts / num_ranks); + + // Top-k checks + int num_topk = 0; + EP_HOST_ASSERT(topk_idx.has_value()); + if (topk_idx.has_value()) { + num_topk = static_cast(topk_idx->size(1)); + EP_HOST_ASSERT(num_experts > 0); + EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous()); + EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(num_tokens == new_topk_idx.size(0)); + EP_HOST_ASSERT(num_topk == topk_weights->size(1)); + EP_HOST_ASSERT(topk_weights->scalar_type() == at::kFloat); + } + + auto device = x.device(); + at::Tensor new_topk_weights; + // for padding + if (topk_weights.has_value()) { + if (!this->is_padding) { + new_topk_weights = topk_weights.value(); + } else { + std::vector weight_blocks; + if (topk_weights->size(0) != 0) { + weight_blocks.emplace_back(topk_weights.value()); + } + for (int i = 0; i < this->padding_cnt; i++) { + at::Tensor tmp_weight = torch::arange(0, num_topk, topk_weights->options()).reshape({1, num_topk}); + weight_blocks.emplace_back(tmp_weight); + } + new_topk_weights = torch::cat(weight_blocks, 0); + } + } else { + new_topk_weights = at::ones({num_tokens, num_topk}, at::dtype(at::kFloat).device(device)); + } + + // FP8 scales checks + float *x_scales_ptr = nullptr; + int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; + if (x_scales.has_value()) { + EP_HOST_ASSERT(new_x.element_size() == 1); + EP_HOST_ASSERT(x_scales->scalar_type() == at::kFloat or x_scales->scalar_type() == at::kInt); + EP_HOST_ASSERT(x_scales->dim() == 2); + EP_HOST_ASSERT(x_scales->size(0) == num_tokens); + num_scales = x_scales->dim() == 1 ? 
1 : static_cast<int>(x_scales->size(1));
+        x_scales_ptr = static_cast<float *>(x_scales->data_ptr());
+        scale_token_stride = static_cast<int>(x_scales->stride(0));
+        scale_hidden_stride = static_cast<int>(x_scales->stride(1));
+    }
+
+    int64_t local_rank_size = A2_MAX_HCCS_PEERS;
+    int32_t server_num = num_ranks / local_rank_size;
+    int64_t local_rank_id = rank % local_rank_size;
+    auto new_num_tokens_per_expert = num_tokens_per_expert.value();
+
+    auto new_send_data = this->send_data;
+    // matches the length of the data produced by get_dispatch_layout
+    int64_t send_count = num_experts * A2_EXPERT_DATA_SIZE + server_num + MAX_BS * (1 + 2 * server_num + num_experts);
+
+    auto send_data_offset = at::empty({num_experts}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor tmp_data =
+        at::empty({send_count * num_ranks}, at::dtype(at::kInt).device(x.device()));  // scratch space used by the notify operator
+    at::Tensor recv_data = at::empty({send_count * num_ranks}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor token_server_idx =
+        at::empty({MAX_BS, server_num}, at::dtype(at::kInt).device(x.device()));  // offset_outer
+    at::Tensor token_unique_per_server = at::empty({server_num}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor ep_rank_token_cnt =
+        at::empty({num_experts, num_ranks}, at::dtype(at::kInt).device(x.device()));  // holds the global counts
+    at::Tensor local_ep_token_cnt =
+        at::empty({num_local_experts, num_ranks}, at::dtype(at::kInt).device(x.device()));  // not a prefix sum
+    at::Tensor src_offset_rank_token_idx =
+        at::empty({num_experts, num_ranks, MAX_BS}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor dst_offset_rank_token_idx =
+        at::empty({num_experts, num_ranks, MAX_BS}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor offset_inner = at::empty({num_ranks, MAX_BS, num_experts}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor count_outer = at::empty({MAX_BS}, at::dtype(at::kInt).device(x.device()));
+    at::Tensor expand_idx = at::empty({MAX_BS, num_experts}, at::dtype(at::kInt).device(x.device()));
+
+    // get ep name
+    char hcom_ep_name[HCOMM_NAME_LEN];
+    if (!moe_all_to_all_group_name.empty()) {
+        std::memcpy(hcom_ep_name, moe_all_to_all_group_name.data(), moe_all_to_all_group_name.size() + 1);
+    } else {
+        HCCL_CHECK(HcclGetCommName(ep_comm, hcom_ep_name));
+    }
+
+    EXEC_NPU_CMD(aclnnNotifyDispatchA2, new_send_data, new_num_tokens_per_expert, tmp_data, send_count, num_tokens,
+                 num_topk, num_experts,
+                 hcom_ep_name,  // commGroup
+                 num_ranks,     // rankSize
+                 rank,          // rankId
+                 local_rank_size, local_rank_id,
+                 send_data_offset,  // unused on A2
+                 recv_data, token_server_idx, token_unique_per_server, ep_rank_token_cnt, local_ep_token_cnt,
+                 src_offset_rank_token_idx, dst_offset_rank_token_idx, offset_inner, count_outer, expand_idx);
+
+    int total_recv_tokens = 0;
+    std::vector num_recv_tokens_per_expert_list;
+
+    auto ep_rank_token_cnt_cpu = ep_rank_token_cnt.to(at::kCPU);
+    auto ep_rank_token_cnt_ptr = ep_rank_token_cnt_cpu.data_ptr<int>();
+    for (int local_e = 0; local_e < num_local_experts; ++local_e) {
+        int64_t local_expert_recv_tokens = 0;
+        for (int src_rank = 0; src_rank < num_ranks; ++src_rank) {
+            int64_t index = local_e * num_ranks + src_rank;
+            int recv_cnt =
+                ep_rank_token_cnt_ptr[(rank * num_local_experts * num_ranks) + local_e * num_ranks + src_rank];
+
+            total_recv_tokens += recv_cnt;
+            local_expert_recv_tokens += recv_cnt;
+        }
+        num_recv_tokens_per_expert_list.push_back(local_expert_recv_tokens);
+    }
+    int num_recv_tokens = (total_recv_tokens == 0) ? 1 : total_recv_tokens;
+
+    int64_t tp_size = 1;
+    int64_t tp_rank = 0;
+    int64_t expertShardType = 0;
+    int64_t sharedExpertNum = 1;
+    int64_t sharedExpertRankNum = 0;
+    int64_t expertTokenNumsType = 0;
+
+    int64_t quant_mode = use_quant ? DYNAMIC_SCALES : NO_SCALES;
+    int64_t global_bs = static_cast<int64_t>(MAX_BS * num_ranks);
+    at::Tensor expert_ids = new_topk_idx.to(at::kInt);
+    at::Tensor xActiveMask = at::empty({1}, at::dtype(at::kInt).device(x.device()));
+
+    auto expandx_out = use_quant ? at::empty({num_recv_tokens, hidden}, at::dtype(at::kChar).device(x.device()))
+                                 : at::empty({num_recv_tokens, hidden}, x.options());
+    auto dynamic_scales_out = at::empty({num_recv_tokens}, at::dtype(at::kFloat).device(x.device()));
+    auto expertTokenNums = at::zeros({1}, at::dtype(at::kLong).device(x.device()));
+    auto epRecvCount = at::zeros({1}, at::dtype(at::kInt).device(x.device()));
+    auto tpRecvCount = at::zeros({1}, at::dtype(at::kInt).device(x.device()));
+    auto expand_scales = at::empty({num_recv_tokens}, at::dtype(at::kFloat).device(x.device()));
+    at::Tensor dispatch_wait_recv_cost_stats_out;
+
+    EXEC_NPU_CMD(aclnnDispatchNormalA2, new_x, expert_ids, x_scales, xActiveMask, new_topk_weights, token_server_idx,
+                 token_unique_per_server, ep_rank_token_cnt, src_offset_rank_token_idx, dst_offset_rank_token_idx,
+                 hcom_ep_name, num_ranks, rank, num_experts, hcom_ep_name, tp_size, tp_rank, expertShardType,
+                 sharedExpertNum, sharedExpertRankNum, quant_mode, global_bs, expertTokenNumsType, expandx_out,
+                 dynamic_scales_out, expand_idx, expertTokenNums, epRecvCount, expand_scales,
+                 dispatch_wait_recv_cost_stats_out);
+
+    auto recv_topk_idx = std::optional();
+    auto recv_topk_weights = std::optional();
+    if (topk_idx.has_value()) {
+        recv_topk_idx = at::empty({num_recv_tokens, num_topk}, topk_idx->options());
+        recv_topk_weights = at::empty({num_recv_tokens, num_topk}, topk_weights->options());
+    }
+    // Wait streams
+    std::optional event;
+
+    return {expandx_out,
+            dynamic_scales_out,
+            recv_topk_idx,
+            recv_topk_weights,
+            num_recv_tokens_per_expert_list,
+            expand_idx,
+            ep_rank_token_cnt,
+            offset_inner,
+            token_server_idx,
+            count_outer,
+            expand_scales,
+            event};
+}
+
+std::tuple, std::optional> Buffer::internode_combine(
+    const torch::Tensor &x, const torch::Tensor &topk_idx, const std::optional &topk_weights,
+    const torch::Tensor &src_idx, const torch::Tensor &send_head, const torch::Tensor &offsetInner,
+    const torch::Tensor &offsetOuter, const torch::Tensor &countOuter, const torch::Tensor &expand_scales)
+{
+    EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous());
+    at::Tensor recv_x = x;
+
+    at::Tensor topk_idx_p = topk_idx;
+    if (this->is_padding) {
+        topk_idx_p = this->new_topk_idx;
+    }
+
+    auto topk_idx_int32 = topk_idx_p.to(at::kInt);
+    at::Tensor expert_ids = topk_idx_int32;
+    at::Tensor expand_idx = src_idx;        // on A2 this is the expanded form, [bs, k] --> [bs, num_expert]; not actually used by the operator
+    at::Tensor ep_send_counts = send_head;  // A2 needs the global counts, [num_expert, num_rank]
+    auto device = x.device();
+
+    const int num_tokens = topk_idx_p.size(0);
+    const int num_topk = topk_idx_p.size(1);
+    at::Tensor expert_scales = at::empty({1}, at::dtype(at::kFloat).device(x.device()));
+
+    int64_t hidden = static_cast<int64_t>(recv_x.size(1));
+    at::Tensor tp_send_counts = at::empty({1}, at::dtype(at::kInt).device(device));
+    int64_t tp_world_size = 1;
+    int64_t tp_rankId = 0;
+    int64_t moe_expert_number = send_head.size(0);
+    int64_t global_bs = static_cast<int64_t>(MAX_BS * num_ranks);
+
+    // get ep & tp name
+    char hcom_ep_name[HCOMM_NAME_LEN];
+    if
(!moe_all_to_all_group_name.empty()) { + std::memcpy(hcom_ep_name, moe_all_to_all_group_name.data(), moe_all_to_all_group_name.size() + 1); + } else { + HCCL_CHECK(HcclGetCommName(ep_comm, hcom_ep_name)); + } + + // Combine data + auto combined_x = torch::empty({new_topk_idx.size(0), hidden}, x.options()); + std::optional recv_topk_weights; + std::optional event; + at::Tensor x_active_mask, activation_scale, weight_scale, group_list; + int64_t expert_shared_type = 0; + int64_t out_dtype = 0; + int64_t comm_quant_mode = 0; + int64_t group_list_type = 0; + + EXEC_NPU_CMD(aclnnMoeDistributeCombineA2, recv_x, expert_ids, expand_idx, ep_send_counts, expert_scales, + tp_send_counts, x_active_mask, activation_scale, weight_scale, group_list, expand_scales, offsetInner, + offsetOuter, countOuter, hcom_ep_name, num_ranks, rank, moe_expert_number, hcom_ep_name, tp_world_size, + tp_rankId, expert_shared_type, shared_expert_num, shared_expert_rank_num, global_bs, out_dtype, + comm_quant_mode, group_list_type, combined_x); + + if (this->is_padding) { + if (this->padding_cnt == PADDING_SIZE) { + combined_x = this->ori_x; + } else { + combined_x = combined_x.slice(0, 0, PADDING_SIZE - this->padding_cnt); + } + is_padding = false; + } + return {combined_x, recv_topk_weights, event}; +} + std::tuple, at::Tensor, at::Tensor, at::Tensor, std::optional, std::optional>> Buffer::low_latency_dispatch(const at::Tensor &x, const at::Tensor &topk_idx, diff --git a/csrc/deepep/deep_ep.hpp b/csrc/deepep/deep_ep.hpp index a613f5dd..28d4d9b2 100644 --- a/csrc/deepep/deep_ep.hpp +++ b/csrc/deepep/deep_ep.hpp @@ -7,6 +7,7 @@ #include #include "hccl/hccl.h" #include "hccl/hccl_types.h" +#include "aclnn/opdev/platform.h" #include "config.hpp" #include "event.hpp" @@ -14,8 +15,9 @@ namespace deep_ep { struct Buffer { - int64_t rank, rdma_rank; - int64_t num_ranks; + int64_t rank, rdma_rank, nvl_rank; + int64_t num_ranks, num_rdma_ranks, num_nvl_ranks; + op::SocVersion soc_version; int64_t num_nvl_bytes; int64_t num_rdma_bytes; @@ -26,6 +28,7 @@ struct Buffer { at::Tensor ori_x; at::Tensor new_topk_idx; at::Tensor new_scales; + at::Tensor send_data; int64_t shared_expert_rank_num; int64_t shared_expert_num = 1; @@ -47,6 +50,10 @@ struct Buffer { bool is_available() const; + int get_num_rdma_ranks() const; + + int get_rdma_rank() const; + std::tuple, torch::Tensor, torch::Tensor, std::optional> get_dispatch_layout(const torch::Tensor &topk_idx, int num_experts, std::optional &previous_event, bool async, bool allocate_on_comm_stream); @@ -70,6 +77,22 @@ struct Buffer { const std::optional &topk_weights, const torch::Tensor &src_idx, const torch::Tensor &send_head, const std::optional &combine_send_cost_stats); + std::tuple, std::optional, std::optional, + std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, + torch::Tensor, std::optional> + internode_dispatch(const torch::Tensor &x, const std::optional &x_scales, + const std::optional &topk_idx, const std::optional &topk_weights, + const std::optional &num_tokens_per_rank, + const std::optional &num_tokens_per_rdma_rank, + const torch::Tensor &is_token_in_rank, const std::optional &num_tokens_per_expert, + const Config &config, std::optional &previous_event, bool async, + bool allocate_on_comm_stream, bool use_quant); + + std::tuple, std::optional> internode_combine( + const torch::Tensor &x, const torch::Tensor &topk_idx, const std::optional &topk_weights, + const torch::Tensor &src_idx, const torch::Tensor &send_head, const torch::Tensor 
&offsetInner, + const torch::Tensor &offsetOuter, const torch::Tensor &countOuter, const torch::Tensor &expand_scales); + std::tuple, at::Tensor, at::Tensor, at::Tensor, std::optional, std::optional>> low_latency_dispatch(const at::Tensor &x, const at::Tensor &topk_idx, @@ -77,8 +100,6 @@ struct Buffer { int64_t num_max_dispatch_tokens_per_rank, int64_t num_experts, bool use_fp8, bool round_scale, bool use_ue8m0, bool async, bool return_recv_hook); - int get_rdma_rank() const; - std::tuple, std::optional>> low_latency_combine( const at::Tensor &x, const at::Tensor &topk_idx, const at::Tensor &topk_weights, const at::Tensor &src_info, const at::Tensor &layout_range, int64_t num_max_dispatch_tokens_per_rank, int64_t num_experts, diff --git a/csrc/deepep/ops2/CMakeLists.txt b/csrc/deepep/ops2/CMakeLists.txt new file mode 100644 index 00000000..359beca2 --- /dev/null +++ b/csrc/deepep/ops2/CMakeLists.txt @@ -0,0 +1,67 @@ +cmake_minimum_required(VERSION 3.16.0) +project(opp) + +if(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64") + set(CANN_HOST_ARCH "x86_64") +elseif(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "aarch64") + set(CANN_HOST_ARCH "aarch64") +else() + message(FATAL_ERROR "Unsupported host processor: ${CMAKE_SYSTEM_PROCESSOR}") +endif() + +include(cmake/config.cmake) +include(cmake/func.cmake) +include(cmake/intf.cmake) + +set(CMAKE_COMPILE ${CMAKE_CXX_COMPILER}) + +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/framework) + add_subdirectory(framework) +endif() +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/op_host) + add_subdirectory(op_host) +endif() +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/op_kernel) + add_subdirectory(op_kernel) +endif() +if(ENABLE_TEST AND EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/testcases) + add_subdirectory(testcases) +endif() + +add_custom_command(OUTPUT ${CMAKE_BINARY_DIR}/scripts/install.sh ${CMAKE_BINARY_DIR}/scripts/upgrade.sh + COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_BINARY_DIR}/scripts + COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/scripts/ ${CMAKE_BINARY_DIR}/scripts/ + COMMAND sed -i "s/vendor_name=customize/vendor_name=${vendor_name}/g" ${CMAKE_BINARY_DIR}/scripts/install.sh ${CMAKE_BINARY_DIR}/scripts/upgrade.sh + VERBATIM +) +add_custom_target(modify_vendor ALL DEPENDS ${CMAKE_BINARY_DIR}/scripts/install.sh ${CMAKE_BINARY_DIR}/scripts/upgrade.sh) + +get_system_info(SYSTEM_INFO) + +# gen version.info +add_custom_target(gen_version_info ALL + COMMAND bash ${CMAKE_CURRENT_SOURCE_DIR}/cmake/util/gen_version_info.sh ${ASCEND_CANN_PACKAGE_PATH} ${CMAKE_CURRENT_BINARY_DIR} +) + +if(NOT ASCEND_PACK_SHARED_LIBRARY) + install(DIRECTORY ${CMAKE_BINARY_DIR}/scripts/ DESTINATION . 
FILE_PERMISSIONS OWNER_EXECUTE OWNER_READ GROUP_READ) + + install(FILES ${CMAKE_SOURCE_DIR}/custom.proto DESTINATION packages OPTIONAL) + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/version.info + DESTINATION packages/vendors/${vendor_name}/) + + # CPack config + set(CPACK_PACKAGE_NAME ${CMAKE_PROJECT_NAME}) + set(CPACK_PACKAGE_VERSION ${CMAKE_PROJECT_VERSION}) + set(CPACK_PACKAGE_DESCRIPTION "CPack opp project") + set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "CPack opp project") + set(CPACK_PACKAGE_DIRECTORY ${CMAKE_INSTALL_PREFIX}) + + set(CPACK_PACKAGE_FILE_NAME "custom_opp_${SYSTEM_INFO}_${CMAKE_SYSTEM_PROCESSOR}.run") + set(CPACK_GENERATOR External) + set(CPACK_CMAKE_GENERATOR "Unix Makefiles") + set(CPACK_EXTERNAL_ENABLE_STAGING TRUE) + set(CPACK_EXTERNAL_PACKAGE_SCRIPT ${CMAKE_SOURCE_DIR}/cmake/makeself.cmake) + set(CPACK_EXTERNAL_BUILT_PACKAGES ${CPACK_PACKAGE_DIRECTORY}/_CPack_Packages/Linux/External/${CPACK_PACKAGE_FILE_NAME}/${CPACK_PACKAGE_FILE_NAME}) + include(CPack) +endif() diff --git a/csrc/deepep/ops2/CMakePresets.json b/csrc/deepep/ops2/CMakePresets.json new file mode 100644 index 00000000..f87df921 --- /dev/null +++ b/csrc/deepep/ops2/CMakePresets.json @@ -0,0 +1,59 @@ +{ + "version": 1, + "cmakeMinimumRequired": { + "major": 3, + "minor": 19, + "patch": 0 + }, + "configurePresets": [ + { + "name": "default", + "displayName": "Default Config", + "description": "Default build using Unix Makefiles generator for native compilation", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build_out", + "cacheVariables": { + "CMAKE_BUILD_TYPE": { + "type": "STRING", + "value": "Release" + }, + "ENABLE_SOURCE_PACKAGE": { + "type": "BOOL", + "value": "True" + }, + "ENABLE_BINARY_PACKAGE": { + "type": "BOOL", + "value": "True" + }, + "ASCEND_COMPUTE_UNIT": { + "type": "STRING", + "value": "ascend910b" + }, + "ENABLE_TEST": { + "type": "BOOL", + "value": "True" + }, + "vendor_name": { + "type": "STRING", + "value": "hwcomputing" + }, + "ASCEND_CANN_PACKAGE_PATH": { + "type": "PATH", + "value": "/usr/local/Ascend/ascend-toolkit/latest" + }, + "ASCEND_PYTHON_EXECUTABLE": { + "type": "STRING", + "value": "python3" + }, + "CMAKE_INSTALL_PREFIX": { + "type": "PATH", + "value": "${sourceDir}/build_out" + }, + "ASCEND_PACK_SHARED_LIBRARY": { + "type": "BOOL", + "value": "False" + } + } + } + ] +} diff --git a/csrc/deepep/ops2/README.md b/csrc/deepep/ops2/README.md new file mode 100644 index 00000000..2856efe8 --- /dev/null +++ b/csrc/deepep/ops2/README.md @@ -0,0 +1,10 @@ +# moe_dispatch_combine + +# compile +bash build.sh + +# install +./build_out/custom_opp_ubuntu_aarch64.run --install-path=/usr/local/Ascend/ascend-toolkit/latest/opp/ + +# require import env parameters before running +source /usr/local/Ascend/ascend-toolkit/latest/opp/vendors/hwcomputing/bin/set_env.bash diff --git a/csrc/deepep/ops2/build.sh b/csrc/deepep/ops2/build.sh new file mode 100755 index 00000000..7692e4e5 --- /dev/null +++ b/csrc/deepep/ops2/build.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +export OPS_PROJECT_NAME=aclnnInner + +SCRIPTS_DIR=$(cd "$(dirname "$0")" && pwd) + +if [ -n "$BASE_LIBS_PATH" ]; then + export ASCEND_HOME_PATH="$BASE_LIBS_PATH" +elif [ -z "$ASCEND_HOME_PATH" ]; then + if [ -n "$ASCEND_AICPU_PATH" ]; then + export ASCEND_HOME_PATH="$ASCEND_AICPU_PATH" + else + echo "please set env." 
>&2 + exit 1 + fi +fi +echo "using ASCEND_HOME_PATH: $ASCEND_HOME_PATH" +script_path=$(realpath $(dirname $0)) + +BUILD_DIR="build_out" +HOST_NATIVE_DIR="host_native_tiling" + +chmod +x cmake/util/gen_ops_filter.sh +mkdir -p build_out +rm -rf build_out/* + +opts=$(python3 $script_path/cmake/util/preset_parse.py $script_path/CMakePresets.json) +ENABLE_CROSS="-DENABLE_CROSS_COMPILE=True" +ENABLE_BINARY="-DENABLE_BINARY_PACKAGE=True" +ENABLE_LIBRARY="-DASCEND_PACK_SHARED_LIBRARY=True" +cmake_version=$(cmake --version | grep "cmake version" | awk '{print $3}') + +target=package +if [ -n "$1" ]; then target="$1"; fi +if [[ $opts =~ $ENABLE_LIBRARY ]]; then target=install; fi + +if [[ $opts =~ $ENABLE_CROSS ]] && [[ $opts =~ $ENABLE_BINARY ]] +then + if [ "$cmake_version" \< "3.19.0" ] ; then + cmake -S . -B "$BUILD_DIR" $opts -DENABLE_CROSS_COMPILE=0 + else + cmake -S . -B "$BUILD_DIR" --preset=default -DENABLE_CROSS_COMPILE=0 + fi + cmake --build "$BUILD_DIR" --target cust_optiling + mkdir $BUILD_DIR/$HOST_NATIVE_DIR + lib_path=$(find "$BUILD_DIR" -name "libcust_opmaster_rt2.0.so") + if [ -z "$lib_path" ] || [ $(echo "$lib_path" | wc -l) -ne 1 ]; then + echo "Error: Expected to find exactly one libcust_opmaster_rt2.0.so, but found none or multiple." >&2 + exit 1 + fi + mv "$lib_path" "$BUILD_DIR/$HOST_NATIVE_DIR/" + find "$BUILD_DIR" -mindepth 1 -maxdepth 1 ! -name "$HOST_NATIVE_DIR" -exec rm -rf {} + + host_native_tiling_lib=$(realpath $(find $BUILD_DIR -type f -name "libcust_opmaster_rt2.0.so")) + if [ "$cmake_version" \< "3.19.0" ] ; then + cmake -S . -B "$BUILD_DIR" $opts -DHOST_NATIVE_TILING_LIB=$host_native_tiling_lib + else + cmake -S . -B "$BUILD_DIR" --preset=default -DHOST_NATIVE_TILING_LIB=$host_native_tiling_lib + fi + cmake --build "$BUILD_DIR" --target binary -j$(nproc) + cmake --build "$BUILD_DIR" --target $target -j$(nproc) +else + if [ "$cmake_version" \< "3.19.0" ] ; then + cmake -S . -B "$BUILD_DIR" $opts + else + cmake -S . 
-B "$BUILD_DIR" --preset=default + fi + cmake --build "$BUILD_DIR" --target binary -j$(nproc) + cmake --build "$BUILD_DIR" --target $target -j$(nproc) +fi diff --git a/csrc/deepep/ops2/cmake/config.cmake b/csrc/deepep/ops2/cmake/config.cmake new file mode 100755 index 00000000..d990cd31 --- /dev/null +++ b/csrc/deepep/ops2/cmake/config.cmake @@ -0,0 +1,42 @@ + +set(CMAKE_CXX_FLAGS_DEBUG "") +set(CMAKE_CXX_FLAGS_RELEASE "") + +if (NOT DEFINED CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release CACHE STRING "") +endif() +if (CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) + set(CMAKE_INSTALL_PREFIX "${CMAKE_SOURCE_DIR}/build_out" CACHE PATH "" FORCE) +endif() +if (NOT DEFINED ASCEND_CANN_PACKAGE_PATH) + set(ASCEND_CANN_PACKAGE_PATH /usr/local/Ascend/latest CACHE PATH "") +endif() +if (NOT DEFINED ASCEND_PYTHON_EXECUTABLE) + set(ASCEND_PYTHON_EXECUTABLE python3 CACHE STRING "") +endif() +if (NOT DEFINED ASCEND_COMPUTE_UNIT) + set(ASCEND_COMPUTE_UNIT ascend910_93 CACHE STRING "") +endif() +if (NOT DEFINED ENABLE_TEST) + set(ENABLE_TEST FALSE CACHE BOOL "") +endif() +if (NOT DEFINED ENABLE_CROSS_COMPILE) + set(ENABLE_CROSS_COMPILE FALSE CACHE BOOL "") +endif() +if (NOT DEFINED CMAKE_CROSS_PLATFORM_COMPILER) + set(CMAKE_CROSS_PLATFORM_COMPILER "/your/cross/compiler/path" CACHE PATH "") +endif() +if (NOT DEFINED CMAKE_CROSS_LIBRARY_PATH) + set(CMAKE_CROSS_LIBRARY_PATH "" CACHE PATH "") +endif() +if (NOT DEFINED ASCEND_PACK_SHARED_LIBRARY) + set(ASCEND_PACK_SHARED_LIBRARY False CACHE BOOL "") +endif() +set(ASCEND_TENSOR_COMPILER_PATH ${ASCEND_CANN_PACKAGE_PATH}/compiler) +set(ASCEND_CCEC_COMPILER_PATH ${ASCEND_TENSOR_COMPILER_PATH}/ccec_compiler/bin) +set(ASCEND_AUTOGEN_PATH ${CMAKE_BINARY_DIR}/autogen) +file(MAKE_DIRECTORY ${ASCEND_AUTOGEN_PATH}) +set(CUSTOM_COMPILE_OPTIONS "custom_compile_options.ini") +set(CUSTOM_OPC_OPTIONS "custom_opc_options.ini") +file(WRITE ${ASCEND_AUTOGEN_PATH}/${CUSTOM_COMPILE_OPTIONS} "") +file(WRITE ${ASCEND_AUTOGEN_PATH}/${CUSTOM_OPC_OPTIONS} "") diff --git a/csrc/deepep/ops2/cmake/device_task.cmake b/csrc/deepep/ops2/cmake/device_task.cmake new file mode 100755 index 00000000..3b7c0a13 --- /dev/null +++ b/csrc/deepep/ops2/cmake/device_task.cmake @@ -0,0 +1,48 @@ +message(STATUS "TILING SINK TASK BEGIN") +message(STATUS "TARGET: ${TARGET}") +message(STATUS "OPTION: ${OPTION}") +message(STATUS "SRC: ${SRC}") +message(STATUS "VENDOR: ${VENDOR_NAME}") + +set(CMAKE_CXX_COMPILER ${ASCEND_CANN_PACKAGE_PATH}/toolkit/toolchain/hcc/bin/aarch64-target-linux-gnu-g++) +set(CMAKE_C_COMPILER ${ASCEND_CANN_PACKAGE_PATH}/toolkit/toolchain/hcc/bin/aarch64-target-linux-gnu-gcc) + +string(REPLACE " " ";" SRC "${SRC}") +add_library(${TARGET} ${OPTION} + ${SRC} +) +target_compile_definitions(${TARGET} PRIVATE + DEVICE_OP_TILING_LIB + _FORTIFY_SOURCE=2 + google=ascend_private +) +target_include_directories(${TARGET} PRIVATE + ${ASCEND_CANN_PACKAGE_PATH}/include +) +target_compile_options(${TARGET} PRIVATE + -fPIC + -fstack-protector-strong + -fstack-protector-all + -O2 + -std=c++11 + -fvisibility-inlines-hidden + -fvisibility=hidden +) +target_link_libraries(${TARGET} PRIVATE + -Wl,--whole-archive + device_register + c_sec + mmpa + tiling_api + platform_static + ascend_protobuf + exe_meta_device + -Wl,--no-whole-archive +) +target_link_directories(${TARGET} PRIVATE + ${ASCEND_CANN_PACKAGE_PATH}/lib64/device/lib64 + ${ASCEND_CANN_PACKAGE_PATH}/compiler/lib64 +) +set_target_properties(${TARGET} PROPERTIES + OUTPUT_NAME cust_opmaster +) diff --git a/csrc/deepep/ops2/cmake/func.cmake 
b/csrc/deepep/ops2/cmake/func.cmake new file mode 100755 index 00000000..aff39581 --- /dev/null +++ b/csrc/deepep/ops2/cmake/func.cmake @@ -0,0 +1,368 @@ +include(ExternalProject) + +function(get_system_info SYSTEM_INFO) + if (UNIX) + execute_process(COMMAND grep -i ^id= /etc/os-release OUTPUT_VARIABLE TEMP) + string(REGEX REPLACE "\n|id=|ID=|\"" "" SYSTEM_NAME ${TEMP}) + set(${SYSTEM_INFO} ${SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR} PARENT_SCOPE) + elseif (WIN32) + message(STATUS "System is Windows. Only for pre-build.") + else () + message(FATAL_ERROR "${CMAKE_SYSTEM_NAME} not support.") + endif () +endfunction() + +function(opbuild) + message(STATUS "Opbuild generating sources") + cmake_parse_arguments(OPBUILD "" "OUT_DIR;PROJECT_NAME;ACCESS_PREFIX;ENABLE_SOURCE" "OPS_SRC" ${ARGN}) + execute_process(COMMAND ${CMAKE_COMPILE} -g -fPIC -shared -std=c++11 ${OPBUILD_OPS_SRC} -D_GLIBCXX_USE_CXX11_ABI=0 + -I ${ASCEND_CANN_PACKAGE_PATH}/include -I ${CMAKE_CURRENT_SOURCE_DIR}/../op_kernel + -L ${ASCEND_CANN_PACKAGE_PATH}/lib64 -lexe_graph -lregister -ltiling_api + -o ${OPBUILD_OUT_DIR}/libascend_all_ops.so + RESULT_VARIABLE EXEC_RESULT + OUTPUT_VARIABLE EXEC_INFO + ERROR_VARIABLE EXEC_ERROR + ) + if (${EXEC_RESULT}) + message("build ops lib info: ${EXEC_INFO}") + message("build ops lib error: ${EXEC_ERROR}") + message(FATAL_ERROR "opbuild run failed!") + endif() + set(proj_env "") + set(prefix_env "") + if (NOT "${OPBUILD_PROJECT_NAME}x" STREQUAL "x") + set(proj_env "OPS_PROJECT_NAME=${OPBUILD_PROJECT_NAME}") + endif() + if (NOT "${OPBUILD_ACCESS_PREFIX}x" STREQUAL "x") + set(prefix_env "OPS_DIRECT_ACCESS_PREFIX=${OPBUILD_ACCESS_PREFIX}") + endif() + + set(ENV{ENABLE_SOURCE_PACKAGE} ${OPBUILD_ENABLE_SOURCE}) + if(${ASCEND_PACK_SHARED_LIBRARY}) + if (NOT vendor_name) + message(FATAL_ERROR "ERROR: vendor_name is invalid!") + return() + endif() + set(ENV{ASCEND_VENDOR_NAME} ${vendor_name}) + set(ENV{OPS_PRODUCT_NAME} ${ASCEND_COMPUTE_UNIT}) + set(ENV{SYSTEM_PROCESSOR} ${CMAKE_SYSTEM_PROCESSOR}) + endif() + execute_process(COMMAND ${proj_env} ${prefix_env} ${ASCEND_CANN_PACKAGE_PATH}/toolkit/tools/opbuild/op_build + ${OPBUILD_OUT_DIR}/libascend_all_ops.so ${OPBUILD_OUT_DIR} + RESULT_VARIABLE EXEC_RESULT + OUTPUT_VARIABLE EXEC_INFO + ERROR_VARIABLE EXEC_ERROR + ) + unset(ENV{ENABLE_SOURCE_PACKAGE}) + if(${ASCEND_PACK_SHARED_LIBRARY}) + unset(ENV{ASCEND_VENDOR_NAME}) + unset(ENV{OPS_PRODUCT_NAME}) + unset(ENV{SYSTEM_PROCESSOR}) + endif() + if (${EXEC_RESULT}) + message("opbuild ops info: ${EXEC_INFO}") + message("opbuild ops error: ${EXEC_ERROR}") + endif() + message(STATUS "Opbuild generating sources - done") +endfunction() + +function(add_ops_info_target) + cmake_parse_arguments(OPINFO "" "TARGET;OPS_INFO;OUTPUT;INSTALL_DIR" "" ${ARGN}) + get_filename_component(opinfo_file_path "${OPINFO_OUTPUT}" DIRECTORY) + add_custom_command(OUTPUT ${OPINFO_OUTPUT} + COMMAND mkdir -p ${opinfo_file_path} + COMMAND ${ASCEND_PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/util/parse_ini_to_json.py + ${OPINFO_OPS_INFO} ${OPINFO_OUTPUT} + ) + add_custom_target(${OPINFO_TARGET} ALL + DEPENDS ${OPINFO_OUTPUT} + ) + if(NOT ${ASCEND_PACK_SHARED_LIBRARY}) + install(FILES ${OPINFO_OUTPUT} + DESTINATION ${OPINFO_INSTALL_DIR} + ) + endif() +endfunction() + +function(add_ops_compile_options OP_TYPE) + cmake_parse_arguments(OP_COMPILE "" "OP_TYPE" "COMPUTE_UNIT;OPTIONS" ${ARGN}) + execute_process(COMMAND ${ASCEND_PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/util/ascendc_gen_options.py + 
${ASCEND_AUTOGEN_PATH}/${CUSTOM_COMPILE_OPTIONS} ${OP_TYPE} ${OP_COMPILE_COMPUTE_UNIT} + ${OP_COMPILE_OPTIONS} + RESULT_VARIABLE EXEC_RESULT + OUTPUT_VARIABLE EXEC_INFO + ERROR_VARIABLE EXEC_ERROR) + if (${EXEC_RESULT}) + message("add ops compile options info: ${EXEC_INFO}") + message("add ops compile options error: ${EXEC_ERROR}") + message(FATAL_ERROR "add ops compile options failed!") + endif() +endfunction() + +function(add_npu_support_target) + cmake_parse_arguments(NPUSUP "" "TARGET;OPS_INFO_DIR;OUT_DIR;INSTALL_DIR" "" ${ARGN}) + get_filename_component(npu_sup_file_path "${NPUSUP_OUT_DIR}" DIRECTORY) + add_custom_command(OUTPUT ${NPUSUP_OUT_DIR}/npu_supported_ops.json + COMMAND mkdir -p ${NPUSUP_OUT_DIR} + COMMAND ${CMAKE_SOURCE_DIR}/cmake/util/gen_ops_filter.sh + ${NPUSUP_OPS_INFO_DIR} + ${NPUSUP_OUT_DIR} + ) + add_custom_target(npu_supported_ops ALL + DEPENDS ${NPUSUP_OUT_DIR}/npu_supported_ops.json + ) + if(NOT ${ASCEND_PACK_SHARED_LIBRARY}) + install(FILES ${NPUSUP_OUT_DIR}/npu_supported_ops.json + DESTINATION ${NPUSUP_INSTALL_DIR} + ) + endif() +endfunction() + +function(add_simple_kernel_compile) + set(options "") + set(single_value_args "OPS_INFO;OUT_DIR;TILING_LIB;OP_TYPE;SRC;COMPUTE_UNIT;JSON_FILE;DYNAMIC_PATH") + set(multi_value_args "OPTIONS;CONFIGS") + cmake_parse_arguments(BINCMP "${options}" "${single_value_args}" "${multi_value_args}" ${ARGN}) + if (NOT DEFINED BINCMP_OUT_DIR) + set(BINCMP_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/binary) + endif() + if (NOT DEFINED BINCMP_TILING_LIB) + set(BINCMP_TILING_LIB $) + endif() + if (${ASCEND_PACK_SHARED_LIBRARY}) + if (NOT TARGET op_kernel_pack) + add_custom_target(op_kernel_pack + COMMAND ${ASCEND_PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/util/ascendc_pack_kernel.py + --input-path=${BINCMP_OUT_DIR} + --output-path=${BINCMP_OUT_DIR}/library + --enable-library=${ASCEND_PACK_SHARED_LIBRARY} + --platform=${CMAKE_SYSTEM_PROCESSOR}) + add_library(ascend_kernels INTERFACE) + target_link_libraries(ascend_kernels INTERFACE kernels) + target_link_directories(ascend_kernels INTERFACE ${BINCMP_OUT_DIR}/library) + target_include_directories(ascend_kernels INTERFACE ${BINCMP_OUT_DIR}/library) + add_dependencies(ascend_kernels op_kernel_pack) + add_dependencies(op_kernel_pack ${BINCMP_OP_TYPE}_${BINCMP_COMPUTE_UNIT}) + endif() + endif() + # add Environment Variable Configurations of ccache + set(_ASCENDC_ENV_VAR) + if(${CMAKE_CXX_COMPILER_LAUNCHER} MATCHES "ccache$") + list(APPEND _ASCENDC_ENV_VAR export ASCENDC_CCACHE_EXECUTABLE=${CMAKE_CXX_COMPILER_LAUNCHER} &&) + endif() + + if (NOT DEFINED BINCMP_OPS_INFO) + set(BINCMP_OPS_INFO ${ASCEND_AUTOGEN_PATH}/aic-${BINCMP_COMPUTE_UNIT}-ops-info.ini) + endif() + if (NOT ${ENABLE_CROSS_COMPILE}) + add_custom_target(${BINCMP_OP_TYPE}_${BINCMP_COMPUTE_UNIT} + COMMAND ${_ASCENDC_ENV_VAR} ${ASCEND_PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/util/ascendc_compile_kernel.py + --op-name=${BINCMP_OP_TYPE} + --src-file=${BINCMP_SRC} + --compute-unit=${BINCMP_COMPUTE_UNIT} + --compile-options=\"${BINCMP_OPTIONS}\" + --debug-config=\"${BINCMP_CONFIGS}\" + --config-ini=${BINCMP_OPS_INFO} + --tiling-lib=${BINCMP_TILING_LIB} + --output-path=${BINCMP_OUT_DIR} + --dynamic-dir=${BINCMP_DYNAMIC_PATH} + --enable-binary=\"${ENABLE_BINARY_PACKAGE}\" + --json-file=${BINCMP_JSON_FILE} + --build-tool=$(MAKE)) + add_dependencies(${BINCMP_OP_TYPE}_${BINCMP_COMPUTE_UNIT} cust_optiling) + else() + if (${ENABLE_BINARY_PACKAGE} AND NOT DEFINED HOST_NATIVE_TILING_LIB) + message(FATAL_ERROR "Native host libs was not set for cross 
compile!") + endif() + add_custom_target(${BINCMP_OP_TYPE}_${BINCMP_COMPUTE_UNIT} + COMMAND ${_ASCENDC_ENV_VAR} ${ASCEND_PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/util/ascendc_compile_kernel.py + --op-name=${BINCMP_OP_TYPE} + --src-file=${BINCMP_SRC} + --compute-unit=${BINCMP_COMPUTE_UNIT} + --compile-options=\"${BINCMP_OPTIONS}\" + --debug-config=\"${BINCMP_CONFIGS}\" + --config-ini=${BINCMP_OPS_INFO} + --tiling-lib=${HOST_NATIVE_TILING_LIB} + --output-path=${BINCMP_OUT_DIR} + --dynamic-dir=${BINCMP_DYNAMIC_PATH} + --enable-binary=\"${ENABLE_BINARY_PACKAGE}\" + --json-file=${BINCMP_JSON_FILE} + --build-tool=$(MAKE)) + endif() + add_dependencies(ascendc_bin_${BINCMP_COMPUTE_UNIT}_gen_ops_config ${BINCMP_OP_TYPE}_${BINCMP_COMPUTE_UNIT}) + add_dependencies(${BINCMP_OP_TYPE}_${BINCMP_COMPUTE_UNIT} ops_info_gen_${BINCMP_COMPUTE_UNIT}) +endfunction() + +function(ascendc_device_library) + message(STATUS "Ascendc device library generating") + cmake_parse_arguments(DEVICE "" "TARGET;OPTION" "SRC" ${ARGN}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/tiling_sink + COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tiling_sink/CMakeLists.txt + ) + execute_process( + COMMAND ${CMAKE_COMMAND} -E echo "cmake_minimum_required(VERSION 3.16.0)\nproject(cust_tiling_sink)\ninclude(${CMAKE_SOURCE_DIR}/cmake/device_task.cmake)\n" + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/tiling_sink/CMakeLists.txt + RESULT_VARIABLE result + ) + string(REPLACE ";" " " DEVICE_SRC "${DEVICE_SRC}") + ExternalProject_Add(tiling_sink_task + SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/tiling_sink + CONFIGURE_COMMAND ${CMAKE_COMMAND} + -DASCEND_CANN_PACKAGE_PATH=${ASCEND_CANN_PACKAGE_PATH} + -DTARGET=${DEVICE_TARGET} + -DOPTION=${DEVICE_OPTION} + -DSRC=${DEVICE_SRC} + -DVENDOR_NAME=${vendor_name} + + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX} + INSTALL_COMMAND "" + BUILD_ALWAYS TRUE + ) + ExternalProject_Get_Property(tiling_sink_task BINARY_DIR) + set(TILINGSINK_LIB_PATH "") + if ("${DEVICE_OPTION}" STREQUAL "SHARED") + set(TILINGSINK_LIB_PATH "${BINARY_DIR}/libcust_opmaster.so") + else() + set(TILINGSINK_LIB_PATH "${BINARY_DIR}/libcust_opmaster.a") + endif() + install(FILES ${TILINGSINK_LIB_PATH} + DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_master_device/lib + ) +endfunction() +function(add_opregistry_target) + string(REPLACE ";" "-" COMPUTE_UNIT "${ASCEND_COMPUTE_UNIT}") + add_custom_target(op_registry_pack + COMMAND ${ASCEND_PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/util/ascendc_pack_opregistry.py + --input-path=${CMAKE_SOURCE_DIR}/build_out/ + --copy-path=${CMAKE_SOURCE_DIR}/build_out/tmp/vendors/${vendor_name}/ + --output-path=${CMAKE_SOURCE_DIR}/build_out/library/ + --vendor-name=${vendor_name} + --compute-unit=${COMPUTE_UNIT} + --framework-type=${ASCEND_FRAMEWORK_TYPE} + --platform=${CMAKE_SYSTEM_PROCESSOR}) + add_library(ascend_opregistry INTERFACE) + target_link_libraries(ascend_opregistry INTERFACE opregistry) + target_link_directories(ascend_opregistry INTERFACE ${CMAKE_SOURCE_DIR}/build_out/library) + target_include_directories(ascend_opregistry INTERFACE ${CMAKE_SOURCE_DIR}/build_out/library) + add_dependencies(ascend_opregistry op_registry_pack) + if(EXISTS "${CMAKE_SOURCE_DIR}/framework/caffe_plugin") + add_dependencies(op_registry_pack cust_caffe_parsers) + elseif(EXISTS "${CMAKE_SOURCE_DIR}/framework/tf_plugin") + add_dependencies(op_registry_pack cust_tf_parsers) + elseif(EXISTS 
"${CMAKE_SOURCE_DIR}/framework/onnx_plugin") + add_dependencies(op_registry_pack cust_onnx_parsers) + endif() +endfunction() + +function(add_kernels_install) + # install kernel file + if (${ENABLE_SOURCE_PACKAGE}) + install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/binary/dynamic/ + DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/${vendor_name}_impl/dynamic/ + ) + endif() + + # install *.o files and *.json files + if (${ENABLE_BINARY_PACKAGE}) + set(INSTALL_DIR packages/vendors/${vendor_name}/op_impl/ai_core/tbe/) + foreach(compute_unit ${ASCEND_COMPUTE_UNIT}) + install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/binary/${compute_unit}/ + DESTINATION ${INSTALL_DIR}/kernel/${compute_unit}/ + ) + endforeach() + install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/binary/config/ + DESTINATION ${INSTALL_DIR}/kernel/config/ + ) + endif() +endfunction() + +function(add_kernels_compile) + set(DYNAMIC_PATH "") + if (${ENABLE_SOURCE_PACKAGE}) + set(DYNAMIC_PATH ${CMAKE_CURRENT_BINARY_DIR}/binary/dynamic) + file(MAKE_DIRECTORY ${DYNAMIC_PATH}) + file(GLOB KERNEL_FILES "${CMAKE_SOURCE_DIR}/op_kernel/*") + file(COPY ${KERNEL_FILES} DESTINATION ${DYNAMIC_PATH}) + file(REMOVE "${DYNAMIC_PATH}/CMakeLists.txt") + endif() + + foreach(compute_unit ${ASCEND_COMPUTE_UNIT}) + # generate aic-${compute_unit}-ops-info.json + add_ops_info_target(TARGET ops_info_gen_${compute_unit} + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/tbe/op_info_cfg/ai_core/${compute_unit}/aic-${compute_unit}-ops-info.json + OPS_INFO ${ASCEND_AUTOGEN_PATH}/aic-${compute_unit}-ops-info.ini + INSTALL_DIR packages/vendors/${vendor_name}/op_impl/ai_core/tbe/config/${compute_unit} + ) + + # define a target:binary to prevent kernel file from being rebuilt during the preinstall process + if (NOT TARGET binary) + add_custom_target(binary) + endif() + + if (${ENABLE_BINARY_PACKAGE} OR ${ENABLE_SOURCE_PACKAGE}) + if (${ENABLE_BINARY_PACKAGE}) + # gen binary_info_config.json and .json + add_custom_target(ascendc_bin_${compute_unit}_gen_ops_config + COMMAND ${ASCEND_PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/util/insert_simplified_keys.py + -p ${CMAKE_CURRENT_BINARY_DIR}/binary/${compute_unit} + COMMAND ${ASCEND_PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/util/ascendc_ops_config.py + -p ${CMAKE_CURRENT_BINARY_DIR}/binary/${compute_unit} + -s ${compute_unit} + COMMAND ${CMAKE_COMMAND} -E make_directory + ${CMAKE_CURRENT_BINARY_DIR}/binary/config/${compute_unit} + COMMAND mv ${CMAKE_CURRENT_BINARY_DIR}/binary/${compute_unit}/*.json + ${CMAKE_CURRENT_BINARY_DIR}/binary/config/${compute_unit} + ) + else() + if (NOT TARGET ascendc_bin_${compute_unit}_gen_ops_config) + add_custom_target(ascendc_bin_${compute_unit}_gen_ops_config) + endif() + endif() + add_dependencies(binary ascendc_bin_${compute_unit}_gen_ops_config) + + # get op_type-op_name from aic-${compute_unit}-ops-info.ini + execute_process(COMMAND ${ASCEND_PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/cmake/util/ascendc_get_op_name.py + --ini-file=${ASCEND_AUTOGEN_PATH}/aic-${compute_unit}-ops-info.ini + OUTPUT_VARIABLE OP_TYPE_NAME + RESULT_VARIABLE EXEC_RESULT + ERROR_VARIABLE EXEC_ERROR + ) + if (${EXEC_RESULT}) + message(FATAL_ERROR, "get op name failed, gen error: ${EXEC_ERROR}") + endif() + + # compile op one by one with ascendc_compile_kernel.py + string(REPLACE "\n" ";" TYPE_NAME_LIST "${OP_TYPE_NAME}") + foreach(TYPE_NAME IN LISTS TYPE_NAME_LIST) + if (NOT "${TYPE_NAME}" STREQUAL "") + string(REPLACE "-" ";" bin_sep ${TYPE_NAME}) + list(GET bin_sep 0 op_type) + list(GET bin_sep 1 op_file) + 
add_simple_kernel_compile(OP_TYPE ${op_type} + SRC ${CMAKE_SOURCE_DIR}/op_kernel/${op_file}.cpp + COMPUTE_UNIT ${compute_unit} + JSON_FILE ${CMAKE_CURRENT_BINARY_DIR}/tbe/op_info_cfg/ai_core/${compute_unit}/aic-${compute_unit}-ops-info.json + DYNAMIC_PATH ${DYNAMIC_PATH}) + endif() + endforeach() + endif() + endforeach() + + # generate npu_supported_ops.json + add_npu_support_target(TARGET npu_supported_ops + OPS_INFO_DIR ${ASCEND_AUTOGEN_PATH} + OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/tbe/op_info_cfg/ai_core + INSTALL_DIR packages/vendors/${vendor_name}/framework/${ASCEND_FRAMEWORK_TYPE} + ) + + if(ENABLE_TEST AND EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/testcases) + add_subdirectory(testcases) + endif() + + if(NOT ASCEND_PACK_SHARED_LIBRARY) + add_kernels_install() + else() + add_opregistry_target() + endif() +endfunction() diff --git a/csrc/deepep/ops2/cmake/intf.cmake b/csrc/deepep/ops2/cmake/intf.cmake new file mode 100755 index 00000000..d2643bbc --- /dev/null +++ b/csrc/deepep/ops2/cmake/intf.cmake @@ -0,0 +1,28 @@ + +add_library(intf_pub INTERFACE) +target_compile_options(intf_pub INTERFACE + -fPIC + -fvisibility=hidden + -fvisibility-inlines-hidden + $<$:-O2> + $<$:-O0 -g> + $<$:-std=c++11> + $<$,$>:-ftrapv -fstack-check> + $<$:-pthread -Wfloat-equal -Wshadow -Wformat=2 -Wno-deprecated -Wextra> + $,-fstack-protector-strong,-fstack-protector-all> +) +target_compile_definitions(intf_pub INTERFACE + _GLIBCXX_USE_CXX11_ABI=0 + $<$:_FORTIFY_SOURCE=2> +) +target_include_directories(intf_pub INTERFACE ${ASCEND_CANN_PACKAGE_PATH}/include + ${CMAKE_CURRENT_SOURCE_DIR}/op_kernel +) +target_link_options(intf_pub INTERFACE + $<$,EXECUTABLE>:-pie> + $<$:-s> + -Wl,-z,relro + -Wl,-z,now + -Wl,-z,noexecstack +) +target_link_directories(intf_pub INTERFACE ${ASCEND_CANN_PACKAGE_PATH}/lib64) diff --git a/csrc/deepep/ops2/cmake/makeself.cmake b/csrc/deepep/ops2/cmake/makeself.cmake new file mode 100755 index 00000000..82c8da7e --- /dev/null +++ b/csrc/deepep/ops2/cmake/makeself.cmake @@ -0,0 +1,31 @@ +execute_process(COMMAND bash ${CMAKE_CURRENT_LIST_DIR}/util/makeself/makeself.sh + --header ${CMAKE_CURRENT_LIST_DIR}/util/makeself/makeself-header.sh + --help-header ./help.info + --gzip --complevel 4 --nomd5 --sha256 + ./ ${CPACK_PACKAGE_FILE_NAME} "version:1.0" ./install.sh + WORKING_DIRECTORY ${CPACK_TEMPORARY_DIRECTORY} + RESULT_VARIABLE EXEC_RESULT + ERROR_VARIABLE EXEC_ERROR +) + +if (NOT "${EXEC_RESULT}x" STREQUAL "0x") + message(FATAL_ERROR "CPack Command error: ${EXEC_RESULT}\n${EXEC_ERROR}") +endif() + +execute_process(COMMAND cp ${CPACK_EXTERNAL_BUILT_PACKAGES} ${CPACK_PACKAGE_DIRECTORY}/ + COMMAND echo "Copy ${CPACK_EXTERNAL_BUILT_PACKAGES} to ${CPACK_PACKAGE_DIRECTORY}/" + WORKING_DIRECTORY ${CPACK_TEMPORARY_DIRECTORY} + ) + +if (NOT "${CPACK_PACKAGE_DIRECTORY}x" STREQUAL "${CPACK_INSTALL_PREFIX}x") + execute_process( + COMMAND ${CMAKE_COMMAND} -E make_directory ${CPACK_INSTALL_PREFIX} + WORKING_DIRECTORY ${CPACK_TEMPORARY_DIRECTORY} + ) + + execute_process( + COMMAND cp ${CPACK_EXTERNAL_BUILT_PACKAGES} ${CPACK_INSTALL_PREFIX}/ + COMMAND echo "Copy ${CPACK_EXTERNAL_BUILT_PACKAGES} to ${CPACK_INSTALL_PREFIX}/" + WORKING_DIRECTORY ${CPACK_TEMPORARY_DIRECTORY} + ) +endif() diff --git a/csrc/deepep/ops2/cmake/util/__init__.py b/csrc/deepep/ops2/cmake/util/__init__.py new file mode 100755 index 00000000..364083fa --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import sys + +PYF_PATH = 
os.path.dirname(os.path.realpath(__file__)) +sys.path.append(PYF_PATH) diff --git a/csrc/deepep/ops2/cmake/util/ascendc_bin_param_build.py b/csrc/deepep/ops2/cmake/util/ascendc_bin_param_build.py new file mode 100755 index 00000000..388575cd --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/ascendc_bin_param_build.py @@ -0,0 +1,531 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import argparse +import copy +import hashlib +import json +import os +import re +import sys +from collections import defaultdict +from typing import Dict, List, NamedTuple, Set, Tuple + +import const_var +import opdesc_parser + +PYF_PATH = os.path.dirname(os.path.realpath(__file__)) + + +class ParamInfo(NamedTuple): + dtype_list: list + format_list: list + dtype_for_bin_list: dict + format_for_bin_list: dict + + +class BinParamBuilder(opdesc_parser.OpDesc): + def __init__(self: any, op_type: str): + super().__init__(op_type) + self.soc = "" + self.out_path = "" + self.tiling_keys = set() + self.op_debug_config = "" + + def set_soc_version(self: any, soc: str): + self.soc = soc + + def set_out_path(self: any, out_path: str): + self.out_path = out_path + + def set_tiling_key(self: any, tiling_key_info: Set): + if tiling_key_info: + self.tiling_keys.update(tiling_key_info) + + def set_op_debug_config(self: any, op_debug_config: str): + if op_debug_config: + self.op_debug_config = op_debug_config + + def get_full_list(self: any): + dtype_list = [] + for dtype_in in self.input_dtype: + dtype_list.append(dtype_in.split(",")) + for dtype_out in self.output_dtype: + dtype_list.append(dtype_out.split(",")) + + format_list = [] + for fmt_in in self.input_fmt: + format_list.append(fmt_in.split(",")) + for fmt_out in self.output_fmt: + format_list.append(fmt_out.split(",")) + + dtype_for_bin_list = [ + [] for _ in range(len(self.input_dtype) + len(self.output_dtype)) + ] + format_for_bin_list = copy.deepcopy(dtype_for_bin_list) + + for key, value in self.input_dtype_for_bin.items(): + dtype_for_bin_list[key] = value.split(",") + for key, value in self.output_dtype_for_bin.items(): + dtype_for_bin_list[key + len(self.input_dtype)] = value.split(",") + for key, value in self.input_fmt_for_bin.items(): + format_for_bin_list[key] = value.split(",") + for key, value in self.output_fmt_for_bin.items(): + format_for_bin_list[key + len(self.input_dtype)] = value.split(",") + + return ParamInfo( + dtype_list, format_list, dtype_for_bin_list, format_for_bin_list + ) + + def gen_bin_cprs_list(self: any, param_info: ParamInfo): + combine_dict = {} + origin_combine_dict = {} + for cob_idx in range(0, len(self.input_dtype[0].split(","))): + origin_combine = "" + combine = "" + for param_idx in range(0, len(self.input_dtype) + len(self.output_dtype)): + if param_info.dtype_for_bin_list[param_idx]: + combine += param_info.dtype_for_bin_list[param_idx][cob_idx] + else: + combine += param_info.dtype_list[param_idx][cob_idx] + origin_combine += param_info.dtype_list[param_idx][cob_idx] + if param_info.format_for_bin_list[param_idx]: + combine += param_info.format_for_bin_list[param_idx][cob_idx] + else: + combine += param_info.format_list[param_idx][cob_idx] + origin_combine += param_info.format_list[param_idx][cob_idx] + if combine not in combine_dict: + combine_dict[combine] = [] + combine_dict[combine].append(cob_idx) + origin_combine_dict[origin_combine] = cob_idx + for key, value in combine_dict.items(): + if key not in origin_combine_dict: + print(f"WARNING: ForBinQuery {key} not in origin combine") + self.bin_save_list += value + 
continue + if len(value) == 1 and value[0] == origin_combine_dict[key]: + self.bin_save_list += value + continue + self.bin_cprs_head.append(origin_combine_dict[key]) + self.bin_cprs_list.append(value) + for index, sub_list in enumerate(self.bin_cprs_list): + if self.bin_cprs_head[index] not in self.bin_save_list: + continue + sub_list.append(self.bin_cprs_head[index]) + self.bin_save_list += self.bin_cprs_head + + def gen_for_bin_list(self: any, param_info: ParamInfo): + combine_size = len(self.input_dtype[0].split(",")) + input_size = len(self.input_dtype) + output_size = len(self.output_dtype) + + self.input_dtype_for_bin_list = [[] for _ in range(input_size)] + self.output_dtype_for_bin_list = [[] for _ in range(output_size)] + for i in range(0, input_size): + self.input_dtype_for_bin_list[i] = [[] for _ in range(combine_size)] + for i in range(0, output_size): + self.output_dtype_for_bin_list[i] = [[] for _ in range(combine_size)] + self.input_fmt_for_bin_list = copy.deepcopy(self.input_dtype_for_bin_list) + self.output_fmt_for_bin_list = copy.deepcopy(self.output_dtype_for_bin_list) + + for index, sub_list in enumerate(self.bin_cprs_list): + head_idx = self.bin_cprs_head[index] + for cmb_idx in sub_list: + for i in range(0, input_size): + self.input_dtype_for_bin_list[i][head_idx].append( + param_info.dtype_list[i][cmb_idx] + ) + self.input_fmt_for_bin_list[i][head_idx].append( + param_info.format_list[i][cmb_idx] + ) + for i in range(0, output_size): + self.output_dtype_for_bin_list[i][head_idx].append( + param_info.dtype_list[i + input_size][cmb_idx] + ) + self.output_fmt_for_bin_list[i][head_idx].append( + param_info.format_list[i + input_size][cmb_idx] + ) + + def rm_cprs_cmb(self: any, dtype_list, format_list, input_size, output_size): + for i in range(0, input_size): + self.input_dtype_for_bin_list[i] = [ + element + for index, element in enumerate(self.input_dtype_for_bin_list[i]) + if index in self.bin_save_list + ] + self.input_fmt_for_bin_list[i] = [ + element + for index, element in enumerate(self.input_fmt_for_bin_list[i]) + if index in self.bin_save_list + ] + new_dtype_list = [ + element + for index, element in enumerate(dtype_list[i]) + if index in self.bin_save_list + ] + new_dtype_str = "" + for dtype in new_dtype_list: + new_dtype_str += f"{dtype}," + self.input_dtype[i] = new_dtype_str[:-1] + new_format_list = [ + element + for index, element in enumerate(format_list[i]) + if index in self.bin_save_list + ] + new_format_str = "" + for fmt in new_format_list: + new_format_str += f"{fmt}," + self.input_fmt[i] = new_format_str[:-1] + for i in range(0, output_size): + self.output_dtype_for_bin_list[i] = [ + element + for index, element in enumerate(self.output_dtype_for_bin_list[i]) + if index in self.bin_save_list + ] + self.output_fmt_for_bin_list[i] = [ + element + for index, element in enumerate(self.output_fmt_for_bin_list[i]) + if index in self.bin_save_list + ] + new_dtype_list = [ + element + for index, element in enumerate(dtype_list[i + input_size]) + if index in self.bin_save_list + ] + new_dtype_str = "" + for dtype in new_dtype_list: + new_dtype_str += f"{dtype}," + self.output_dtype[i] = new_dtype_str[:-1] + new_format_list = [ + element + for index, element in enumerate(format_list[i + input_size]) + if index in self.bin_save_list + ] + new_format_str = "" + for fmt in new_format_list: + new_format_str += f"{fmt}," + self.output_fmt[i] = new_format_str[:-1] + + def is_set_for_bin_query(self: any): + return any( + [ + self.input_dtype_for_bin, + 
self.output_dtype_for_bin, + self.input_fmt_for_bin, + self.output_fmt_for_bin, + ] + ) + + def for_bin_list_match(self: any): + if not self.is_set_for_bin_query(): + return + input_size = len(self.input_dtype) + output_size = len(self.output_dtype) + param_info = self.get_full_list() + self.gen_bin_cprs_list(param_info) + self.gen_for_bin_list(param_info) + if len(self.bin_save_list) == len(self.input_dtype[0].split(",")): + print( + f"WARNING: ForBinQuery can not compress number of bin file with this set, please check!!." + ) + return + self.rm_cprs_cmb( + param_info.dtype_list, param_info.format_list, input_size, output_size + ) + + def gen_input_json(self: any, auto_gen_path: str): + key_map = {} + self.for_bin_list_match() + count = len(self.input_dtype[0].split(",")) + required_parameters = set() + index_value = -1 + + for i in range(0, count): + inputs = [] + outputs = [] + attrs = [] + required_parameter = [] + op_node = {} + + for idx in range(0, len(self.input_name)): + idtypes = self.input_dtype[idx].split(",") + ifmts = self.input_fmt[idx].split(",") + itype = self.input_type[idx] + para = {} + para["name"] = self.input_name[idx][:-5] + para["index"] = idx + para["dtype"] = idtypes[i] + if ( + self.is_set_for_bin_query() + and self.input_dtype_for_bin_list[idx][i] + ): + para["dtypeForBinQuery"] = self.input_dtype_for_bin_list[idx][i] + para["format"] = ifmts[i] + if self.is_set_for_bin_query() and self.input_fmt_for_bin_list[idx][i]: + para["formatForBinQuery"] = self.input_fmt_for_bin_list[idx][i] + para["paramType"] = itype + para["shape"] = [-2] + para["format_match_mode"] = "FormatAgnostic" + + input_parameter_key = (idtypes[i], ifmts[i]) + if itype == "dynamic": + inputs.append([para]) + required_parameter.append(input_parameter_key) + elif itype == "required": + inputs.append(para) + required_parameter.append(input_parameter_key) + else: + inputs.append(para) + + for idx in range(0, len(self.output_name)): + odtypes = self.output_dtype[idx].split(",") + ofmts = self.output_fmt[idx].split(",") + otype = self.output_type[idx] + para = {} + para["name"] = self.output_name[idx][:-5] + para["index"] = idx + para["dtype"] = odtypes[i] + if ( + self.is_set_for_bin_query() + and self.output_dtype_for_bin_list[idx][i] + ): + para["dtypeForBinQuery"] = self.output_dtype_for_bin_list[idx][i] + para["format"] = ofmts[i] + if self.is_set_for_bin_query() and self.output_fmt_for_bin_list[idx][i]: + para["formatForBinQuery"] = self.output_fmt_for_bin_list[idx][i] + para["paramType"] = otype + para["shape"] = [-2] + para["format_match_mode"] = "FormatAgnostic" + output_parameter_key = (odtypes[i], ofmts[i]) + if otype == "dynamic": + outputs.append([para]) + required_parameter.append(output_parameter_key) + elif otype == "required": + outputs.append(para) + required_parameter.append(output_parameter_key) + else: + outputs.append(para) + + for attr in self.attr_list: + att = {} + att["name"] = attr + atype = self.attr_val.get(attr).get("type").lower() + att["dtype"] = atype + att["value"] = const_var.ATTR_DEF_VAL.get(atype) + attrs.append(att) + + required_parameter_tuple = tuple(required_parameter) + if required_parameter_tuple in required_parameters: + continue + else: + required_parameters.add(required_parameter_tuple) + index_value += 1 + + op_node["bin_filename"] = "" + op_node["inputs"] = inputs + op_node["outputs"] = outputs + if len(attrs) > 0: + op_node["attrs"] = attrs + + param = {} + param["op_type"] = self.op_type + param["op_list"] = [op_node] + objstr = json.dumps(param, 
indent=" ") + md5sum = hashlib.md5(objstr.encode("utf-8")).hexdigest() + while key_map.get(md5sum) is not None: + objstr += "1" + md5sum = hashlib.md5(objstr.encode("utf-8")).hexdigest() + key_map[md5sum] = md5sum + bin_file = self.op_type + "_" + md5sum + op_node["bin_filename"] = bin_file + param_file = os.path.join(self.out_path, bin_file + "_param.json") + param_file = os.path.realpath(param_file) + with os.fdopen( + os.open(param_file, const_var.WFLAGS, const_var.WMODES), "w" + ) as fd: + json.dump(param, fd, indent=" ") + self._write_build_cmd(param_file, bin_file, index_value, auto_gen_path) + + def _write_build_cmd( + self: any, param_file: str, bin_file: str, index: int, auto_gen_path: str + ): + hard_soc = const_var.conv_soc_ver(self.soc) + if not hard_soc: + hard_soc = self.soc.capitalize() + name_com = [self.op_type, self.op_file, str(index)] + compile_file = os.path.join(self.out_path, "-".join(name_com) + ".sh") + compile_file = os.path.realpath(compile_file) + + bin_cmd_str = "res=$(opc $1 --main_func={fun} --input_param={param} --soc_version={soc} \ + --output=$2 --impl_mode={impl} --simplified_key_mode=0 --op_mode=dynamic " + + build_cmd_var = "#!/bin/bash\n" + build_cmd_var += f'echo "[{self.soc}] Generating {bin_file} ..."\n' + plog_level = os.environ.get("ASCEND_GLOBAL_LOG_LEVEL") + plog_stdout = os.environ.get("ASCEND_SLOG_PRINT_TO_STDOUT") + if plog_level is None: + build_cmd_var += const_var.SET_PLOG_LEVEL_ERROR + if plog_stdout is None: + build_cmd_var += const_var.SET_PLOG_STDOUT + build_cmd_var += const_var.SRC_ENV + if hard_soc == "Ascend610Lite": + build_cmd_var += f"export ASCEND_CUSTOM_OPP_PATH={auto_gen_path}:$ASCEND_CUSTOM_OPP_PATH \n" + build_cmd_var += bin_cmd_str.format( + fun=self.op_intf, + soc=hard_soc, + param=param_file, + impl="high_performance,optional", + ) + enable_tiling_keys = False + if self.tiling_keys: + tiling_keys_list = sorted(list(self.tiling_keys)) + tiling_key_str = ",".join([str(_key) for _key in tiling_keys_list]) + build_cmd_var += f' --tiling_key="{tiling_key_str}"' + enable_tiling_keys = True + + if self.op_debug_config: + op_debug_str = ",".join([str(_key) for _key in list(self.op_debug_config)]) + build_cmd_var += f" --op_debug_config={op_debug_str}" + + build_cmd_var += ")\n" + build_cmd_var += "\n" + if enable_tiling_keys is False: + build_cmd_var += 'echo "${res}"\n' + build_cmd_var += const_var.CHK_CMD.format(res_file=bin_file + ".json") + build_cmd_var += const_var.CHK_CMD.format(res_file=bin_file + ".o") + else: + build_cmd_var += "if [ $? -eq 1 ]; then\n" + build_cmd_var += ' if echo "${res}" | \ +grep -q "None of the given tiling keys are in the supported list"; then\n' + build_cmd_var += ' echo "${res}"\n' + build_cmd_var += " else\n" + build_cmd_var += ' echo "${res}"\n' + build_cmd_var += " exit 1\n" + build_cmd_var += " fi\n" + build_cmd_var += "else\n" + build_cmd_var += 'echo "${res}"\n' + build_cmd_var += const_var.CHK_CMD.format(res_file=bin_file + ".json") + build_cmd_var += const_var.CHK_CMD.format(res_file=bin_file + ".o") + build_cmd_var += "fi\n" + build_cmd_var += f'echo "[{self.soc}] Generating {bin_file} Done"\n' + + with os.fdopen( + os.open(compile_file, const_var.WFLAGS, const_var.WMODES), "w" + ) as fd: + fd.write(build_cmd_var) + + +def get_tiling_keys(tiling_keys: str) -> Set: + all_tiling_keys = set() + if not tiling_keys: + return all_tiling_keys + + tiling_key_list = tiling_keys.split(";") + for tiling_key_value in tiling_key_list: + pattern = r"(? 
int(end): + continue + for i in range(int(start), int(end) + 1): + all_tiling_keys.add(i) + elif tiling_key_value.isdigit(): + all_tiling_keys.add(int(tiling_key_value)) + return all_tiling_keys + + +def trans_soc_verion(soc_ver: str): + low_soc_ver = soc_ver.lower() + if low_soc_ver not in opdesc_parser.SOC_TO_SHORT_SOC_MAP: + return low_soc_ver + return opdesc_parser.SOC_TO_SHORT_SOC_MAP[low_soc_ver] + + +def parse_op_debug_confg(opc_config_file: str, soc: str) -> Dict: + tiling_key_info = defaultdict(set) + op_debug_config = defaultdict(set) + if not opc_config_file: + return tiling_key_info, op_debug_config + + if not os.path.exists(opc_config_file): + return tiling_key_info, op_debug_config + + with open(opc_config_file, "r") as file: + contents = file.readlines() + + for _content in contents: + content = _content.strip() + opc_configs = content.split("@") + if len(opc_configs) < 3: + continue + + op_type = opc_configs[0] + if not op_type: + continue + + compute_unit = opc_configs[1] + if compute_unit: + compute_unit_list = compute_unit.split(";") + soc_lists = [] + for soc_ver in compute_unit_list: + short_soc_ver = trans_soc_verion(soc_ver) + soc_lists.append(short_soc_ver) + if soc not in soc_lists: + continue + + for options in opc_configs[2:]: + if "--tiling_key" in options: + format_tiling_keys = get_tiling_keys(options.split("=")[1]) + if format_tiling_keys: + tiling_key_info[op_type].update(format_tiling_keys) + if "--op_debug_config" in options: + format_debug_config = set(options.split("=")[1].split(";")) + if format_debug_config: + op_debug_config[op_type].update(format_debug_config) + + return tiling_key_info, op_debug_config + + +def gen_bin_param_file( + cfgfile: str, out_dir: str, soc: str, opc_config_file: str = "", ops: list = None +): + if not os.path.exists(cfgfile): + print( + f"INFO: {cfgfile} does not exists in this project, skip generating compile commands." 
+ ) + return + + op_descs = opdesc_parser.get_op_desc(cfgfile, [], [], BinParamBuilder, ops) + tiling_key_info, op_debug_config = parse_op_debug_confg(opc_config_file, soc) + auto_gen_path_dir = os.path.dirname(cfgfile) + all_soc_key = "ALL" + for op_desc in op_descs: + op_desc.set_soc_version(soc) + op_desc.set_out_path(out_dir) + if op_desc.op_type in op_debug_config: + op_desc.set_op_debug_config(op_debug_config[op_desc.op_type]) + if all_soc_key in op_debug_config: + op_desc.set_op_debug_config(op_debug_config[all_soc_key]) + if op_desc.op_type in tiling_key_info: + op_desc.set_tiling_key(tiling_key_info[op_desc.op_type]) + if all_soc_key in tiling_key_info: + op_desc.set_tiling_key(tiling_key_info[all_soc_key]) + op_desc.gen_input_json(auto_gen_path_dir) + + +def parse_args(argv): + """Command line parameter parsing""" + parser = argparse.ArgumentParser() + parser.add_argument("argv", nargs="+") + parser.add_argument("--opc-config-file", nargs="?", const="", default="") + return parser.parse_args(argv) + + +if __name__ == "__main__": + args = parse_args(sys.argv) + if len(args.argv) <= 3: + raise RuntimeError("arguments must greater than 3") + gen_bin_param_file( + args.argv[1], args.argv[2], args.argv[3], opc_config_file=args.opc_config_file + ) diff --git a/csrc/deepep/ops2/cmake/util/ascendc_compile_kernel.py b/csrc/deepep/ops2/cmake/util/ascendc_compile_kernel.py new file mode 100755 index 00000000..c4c81576 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/ascendc_compile_kernel.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import argparse +import glob +import os +import shutil +import subprocess +import sys +import time + +import ascendc_bin_param_build +import ascendc_impl_build +import ascendc_op_info +import const_var + + +class CompileKernel: + def __init__(self: any, args: any): + self.op_type = args.op_name + self.op_cpp_file = os.path.realpath(args.src_file) + self.op_soc_ver = args.compute_unit + self.compile_options = args.compile_options + self.op_debug_config = args.debug_config + self.op_cfg_ini = os.path.realpath(args.config_ini) + self.op_tiling = os.path.realpath(args.tiling_lib) + self.op_output = os.path.realpath(args.output_path) + self.op_impl_py = None + self.compile_sh = [] + self.working_dir = os.path.join( + os.getcwd(), + self.op_type + "_" + self.op_soc_ver, + ) + self.build_opp_path = os.path.join(self.working_dir, "customize") + os.makedirs(self.working_dir) + os.makedirs(self.op_output, exist_ok=True) + if args.dynamic_dir is not None and args.dynamic_dir != "": + self.dynamic_dir = os.path.realpath(args.dynamic_dir) + else: + self.dynamic_dir = None + if args.json_file is not None and args.json_file != "": + self.json_file = args.json_file + else: + self.json_file = None + + def clean(self: any): + if "dump_cce" not in self.op_debug_config: + shutil.rmtree(self.working_dir) + return + + def ascendc_gen_impl(self: any): + rep_cfg = {} + rep_cfg[const_var.REPLAY_BATCH] = "" + rep_cfg[const_var.REPLAY_ITERATE] = "" + cfg_dir = {} + cfg_dir[const_var.CFG_IMPL_DIR] = os.path.dirname(self.op_cpp_file) + cfg_dir[const_var.CFG_OUT_DIR] = os.path.join(self.working_dir, "dynamic") + os.makedirs(os.path.join(self.working_dir, "dynamic"), exist_ok=True) + cfg_dir[const_var.AUTO_GEN_DIR] = os.path.dirname(self.op_cfg_ini) + ascendc_impl_build.write_scripts( + self.op_cfg_ini, rep_cfg, cfg_dir, [self.op_type], self.compile_options + ) + py_files = glob.glob(os.path.join(self.working_dir, "dynamic", "*.py")) + if py_files is None or 
len(py_files) != 1: + self.clean() + raise RuntimeError("compile py file {} generated error!".format(py_files)) + self.op_impl_py = os.path.join( + self.working_dir, "dynamic", self.op_type + ".py" + ) + if self.dynamic_dir is not None: + shutil.copy(py_files[0], self.dynamic_dir) + os.rename(py_files[0], self.op_impl_py) + if not os.path.exists(self.op_impl_py): + self.clean() + raise RuntimeError( + "compile py file {} not generated!".format(self.op_impl_py) + ) + + def ascendc_gen_param(self: any): + bin_param_path = os.path.join(self.working_dir, "bin_param") + os.makedirs(bin_param_path) + base_dir = os.path.dirname(self.op_cfg_ini) + opc_config_file = os.path.join(base_dir, "custom_opc_options.ini") + ascendc_bin_param_build.gen_bin_param_file( + self.op_cfg_ini, + bin_param_path, + self.op_soc_ver, + opc_config_file, + [self.op_type], + ) + tiling_key_info, op_debug_config = ascendc_bin_param_build.parse_op_debug_confg( + opc_config_file, self.op_type + ) + if self.op_type in op_debug_config: + self.op_debug_config = op_debug_config[self.op_type] + if "ALL" in op_debug_config: + self.op_debug_config = op_debug_config["ALL"] + bin_param_files = glob.glob(os.path.join(bin_param_path, "*.json")) + if bin_param_files is None or len(bin_param_files) <= 0: + self.clean() + raise RuntimeError("compile binary param json file not generated!") + self.compile_sh = glob.glob(os.path.join(bin_param_path, "*.sh")) + if self.compile_sh is None or len(self.compile_sh) != len(bin_param_files): + self.clean() + raise RuntimeError("compile binary shell file not generated!") + + def ascendc_put_tiling(self: any): + tiling_path = os.path.join( + self.build_opp_path, "op_impl", "ai_core", "tbe", "op_tiling" + ) + os.makedirs(tiling_path) + tiling_so = os.path.join(tiling_path, "liboptiling.so") + os.symlink(self.op_tiling, tiling_so) + if not os.path.exists(tiling_so): + self.clean() + raise RuntimeError("prepare tiling lib {} link failed!".format(tiling_so)) + + def ascendc_put_json(self: any): + if self.json_file is not None: + json_file_dir = os.path.join( + self.build_opp_path, + "op_impl", + "ai_core", + "tbe", + "config", + self.op_soc_ver, + ) + os.makedirs(json_file_dir) + shutil.copy(self.json_file, json_file_dir) + build_json_file = os.path.join( + json_file_dir, "aic-{}-ops-info.json".format(self.op_soc_ver) + ) + if not os.path.exists(build_json_file): + self.clean() + raise RuntimeError( + "prepare json file aic-{}-ops-info.json failed!".format( + self.op_soc_ver + ) + ) + + def ascendc_build(self: any): + op_info = ascendc_op_info.OpInfo(self.op_type, self.op_cfg_ini) + op_file = op_info.get_op_file() + op_bin_dir = os.path.join(self.op_output, self.op_soc_ver, op_file) + os.makedirs(op_bin_dir, exist_ok=True) + all_tar = [] + sub_cmd = [] + index = 0 + for sh in self.compile_sh: + tar = op_file + str(index) + build_path = os.path.join(self.working_dir, "kernel_" + str(index)) + os.makedirs(build_path) + all_tar.append(tar) + sub_cmd.append(tar + ":") + sub_cmd.append( + "\tcd {} && bash {} --kernel-src=$(CPP) $(PY) $(OUT) $(MAKE)".format( + build_path, sh + ) + ) + index += 1 + mkfile = os.path.join(self.working_dir, op_file + ".make") + with os.fdopen(os.open(mkfile, const_var.WFLAGS, const_var.WMODES), "w") as fd: + sub_cmd.insert(0, "all: " + " ".join(all_tar)) + fd.write("\n".join(sub_cmd)) + + if os.getenv("TILINGKEY_PAR_COMPILE") is None: + cmd_str = ( + "export HI_PYTHON=python3 && export ASCEND_CUSTOM_OPP_PATH={} && export TILINGKEY_PAR_COMPILE=1" + "&& make -f {} PY={} OUT={} 
CPP={}" + ) + else: + cmd_str = "export HI_PYTHON=python3 && export ASCEND_CUSTOM_OPP_PATH={} && make -f {} PY={} OUT={} CPP={}" + + if ( + os.system( + cmd_str.format( + self.build_opp_path, + mkfile, + self.op_impl_py, + op_bin_dir, + self.op_cpp_file, + ) + ) + != 0 + ): + raise RuntimeError( + "Kernel Compilation Error: OpType {} Kernel File {}!".format( + self.op_type, self.op_cpp_file + ) + ) + + +def args_parse(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-n", "--op-name", nargs="?", help="Op name(Camel string) to compile." + ) + parser.add_argument("-s", "--src-file", nargs="?", help="Op kernel source file.") + + parser.add_argument("-u", "--compute-unit", nargs="?", help="Compute unit.") + parser.add_argument( + "-c", "--compile-options", nargs="?", help="Compile options of compiler." + ) + parser.add_argument( + "-d", + "--debug-config", + nargs="?", + help="Debug config of op, ref opc op-debug-config.", + ) + parser.add_argument("-i", "--config-ini", nargs="?", help="Op config ini file.") + parser.add_argument( + "-t", "--tiling-lib", nargs="?", help="Tiling shared library file." + ) + + parser.add_argument( + "-o", "--output-path", nargs="?", help="Output path of compile result." + ) + parser.add_argument( + "-dy", + "--dynamic-dir", + nargs="?", + default=None, + help="dynamic path of source compile.", + ) + parser.add_argument( + "-eb", + "--enable-binary", + nargs="?", + default=None, + help="whether binary compile is enabled.", + ) + parser.add_argument( + "-j", + "--json-file", + nargs="?", + default=None, + help="aic--ops-info.json file path.", + ) + # $(MAKE) is necessary for parallel compiling + parser.add_argument( + "-b", "--build-tool", nargs="?", default=None, help="build tool must be make." + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = args_parse() + kernel_builder = CompileKernel(args) + kernel_builder.clean() + if args.enable_binary == "False": + kernel_builder.ascendc_gen_impl() + kernel_builder.clean() + else: + kernel_builder.ascendc_gen_impl() + kernel_builder.ascendc_gen_param() + kernel_builder.ascendc_put_json() + kernel_builder.ascendc_put_tiling() + kernel_builder.ascendc_build() + kernel_builder.clean() diff --git a/csrc/deepep/ops2/cmake/util/ascendc_gen_options.py b/csrc/deepep/ops2/cmake/util/ascendc_gen_options.py new file mode 100755 index 00000000..637ae9d7 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/ascendc_gen_options.py @@ -0,0 +1,83 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +import json +import os +import re +import stat +import sys + +import const_var + + +def write_options_to_file( + file_name: str, options_str: str, op_type: str, compute_unit: str, split_char: str +): + flags = os.O_WRONLY | os.O_CREAT + modes = stat.S_IWUSR | stat.S_IRUSR + try: + with os.fdopen(os.open(file_name, flags, modes), "a") as fd: + fd.write( + op_type + split_char + compute_unit + split_char + options_str + "\n" + ) + except Exception as err: + print("write compile options config file failed") + raise (err) + + +def gen_compile_options( + compile_options_file: str, op_type: str, compute_unit: str, compile_options: list +): + base_dir = os.path.dirname(compile_options_file) + opc_config_file = os.path.join(base_dir, "custom_opc_options.ini") + compile_opt = [] + opc_debug_config = [] + opc_tiling_keys = "" + for opts in compile_options: + if "oom" in opts: + if opts == "--oom": + opc_debug_config.append("oom") + else: + raise RuntimeError(f"Unknown oom option format {opts}") + elif "--save-temp-files" in 
opts: + opc_debug_config.append("dump_cce") + elif "--tiling_key" in opts: + keys = opts.strip().split("=")[1].split(",") + keys_str = ";".join([key for key in keys]) + opc_tiling_keys = keys_str + else: + compile_opt.append(opts) + if len(compile_opt) > 0: + options_str = ";".join([opt for opt in compile_opt]) + write_options_to_file( + compile_options_file, options_str, op_type, compute_unit, "," + ) + opc_config_str = "" + if opc_debug_config: + opc_config_str = "--op_debug_config=" + ";".join( + [opt for opt in opc_debug_config] + ) + if len(opc_tiling_keys) > 0: + if opc_config_str != "": + opc_config_str += "@" + opc_config_str += "--tiling_key=" + opc_tiling_keys + + if opc_config_str != "": + write_options_to_file( + opc_config_file, opc_config_str, op_type, compute_unit, "@" + ) + + +if __name__ == "__main__": + if len(sys.argv) < 4: + raise RuntimeError("arguments must greater than 4") + compute_soc = "" + comp_options = [] + for i in range(len(sys.argv) - 3): + if sys.argv[i + 3].upper().startswith("ASCEND"): + compute_soc += sys.argv[i + 3] + ";" + else: + comp_options.append(sys.argv[i + 3]) + if compute_soc != "": + compute_soc = compute_soc[0:-1] + gen_compile_options(sys.argv[1], sys.argv[2], compute_soc, comp_options) diff --git a/csrc/deepep/ops2/cmake/util/ascendc_get_op_name.py b/csrc/deepep/ops2/cmake/util/ascendc_get_op_name.py new file mode 100755 index 00000000..5da592b3 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/ascendc_get_op_name.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import argparse +import configparser + + +def args_parse(): + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--ini-file", help="op info ini.") + return parser.parse_args() + + +if __name__ == "__main__": + args = args_parse() + op_config = configparser.ConfigParser() + op_config.read(args.ini_file) + for section in op_config.sections(): + print(section, end="-") + print(op_config.get(section, "opFile.value"), end="\n") diff --git a/csrc/deepep/ops2/cmake/util/ascendc_impl_build.py b/csrc/deepep/ops2/cmake/util/ascendc_impl_build.py new file mode 100755 index 00000000..7486ad13 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/ascendc_impl_build.py @@ -0,0 +1,748 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import argparse +import datetime +import glob +import os +import re +import sys +from typing import List + +import const_var +import opdesc_parser + +PYF_PATH = os.path.dirname(os.path.realpath(__file__)) + +IMPL_HEAD = '''#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +Copyright (c) Huawei Technologies Co., Ltd. {}-{}. All rights reserved. 
+""" + +import os, sys +import ctypes +import json +import shutil +from tbe.common.platform import get_soc_spec +from tbe.common.utils import para_check +from tbe.tikcpp import compile_op, replay_op, check_op_cap, generalize_op_params, get_code_channel, OpInfo +from tbe.tikcpp.compile_op import CommonUtility, AscendCLogLevel +from tbe.common.buildcfg import get_default_build_config +from impl.util.platform_adapter import tbe_register +from tbe.common.buildcfg import get_current_build_config +PYF_PATH = os.path.dirname(os.path.realpath(__file__)) + +DTYPE_MAP = {{"float32": ["DT_FLOAT", "float"], + "float16": ["DT_FLOAT16", "half"], + "int8": ["DT_INT8", "int8_t"], + "int16": ["DT_INT16", "int16_t"], + "int32": ["DT_INT32", "int32_t"], + "int64": ["DT_INT64", "int64_t"], + "uint1": ["DT_UINT1", "uint8_t"], + "uint8": ["DT_UINT8", "uint8_t"], + "uint16": ["DT_UINT16", "uint16_t"], + "uint32": ["DT_UINT32", "uint32_t"], + "uint64": ["DT_UINT64", "uint64_t"], + "bool": ["DT_BOOL", "bool"], + "double": ["DT_DOUBLE", "double"], + "dual": ["DT_DUAL", "unknown"], + "dual_sub_int8": ["DT_DUAL_SUB_INT8", "unknown"], + "dual_sub_uint8": ["DT_DUAL_SUB_UINT8", "unknown"], + "string": ["DT_STRING", "unknown"], + "complex32": ["DT_COMPLEX32", "unknown"], + "complex64": ["DT_COMPLEX64", "unknown"], + "complex128": ["DT_COMPLEX128", "unknown"], + "qint8": ["DT_QINT8", "unknown"], + "qint16": ["DT_QINT16", "unknown"], + "qint32": ["DT_QINT32", "unknown"], + "quint8": ["DT_QUINT8", "unknown"], + "quint16": ["DT_QUINT16", "unknown"], + "resource": ["DT_RESOURCE", "unknown"], + "string_ref": ["DT_STRING_REF", "unknown"], + "int4": ["DT_INT4", "int4b_t"], + "bfloat16": ["DT_BF16", "bfloat16_t"]}} + +def add_dtype_fmt_option_single(x, x_n, is_ref: bool = False): + options = [] + x_fmt = x.get("format") + x_dtype = x.get("dtype") + x_n_in_kernel = x_n + '_REF' if is_ref else x_n + options.append("-DDTYPE_{{n}}={{t}}".format(n=x_n_in_kernel, t=DTYPE_MAP.get(x_dtype)[1])) + options.append("-DORIG_DTYPE_{{n}}={{ot}}".format(n=x_n_in_kernel, ot=DTYPE_MAP.get(x_dtype)[0])) + options.append("-DFORMAT_{{n}}=FORMAT_{{f}}".format(n=x_n_in_kernel, f=x_fmt)) + return options + +def get_dtype_fmt_options(__inputs__, __outputs__): + options = [] + input_names = {} + output_names = {} + unique_param_name_set = set() + for idx, x in enumerate(__inputs__): + if x is None: + continue + x_n = input_names[idx].upper() + unique_param_name_set.add(x_n) + options += add_dtype_fmt_option_single(x, x_n) + + for idx, x in enumerate(__outputs__): + if x is None: + continue + x_n = output_names[idx].upper() + if x_n in unique_param_name_set: + options += add_dtype_fmt_option_single(x, x_n, True) + else: + options += add_dtype_fmt_option_single(x, x_n) + return options + +def load_dso(so_path): + try: + ctypes.CDLL(so_path) + except OSError as error : + CommonUtility.print_compile_log("", error, AscendCLogLevel.LOG_ERROR) + raise RuntimeError("cannot open %s" %(so_path)) + else: + msg = "load so succ " + so_path + CommonUtility.print_compile_log("", msg, AscendCLogLevel.LOG_INFO) + +def get_shortsoc_compile_option(compile_option_list: list, shortsoc:str): + compile_options = [] + if shortsoc in compile_option_list: + compile_options.extend(compile_option_list[shortsoc]) + if '__ALLSOC__' in compile_option_list: + compile_options.extend(compile_option_list['__ALLSOC__']) + return compile_options + +def get_kernel_source(src_file, dir_snake, dir_ex): + src_ex = os.path.join(PYF_PATH, "..", "ascendc", dir_ex, src_file) + if 
os.path.exists(src_ex): + return src_ex + src = os.environ.get('BUILD_KERNEL_SRC') + if src and os.path.exists(src): + return src + src = os.path.join(PYF_PATH, "..", "ascendc", dir_snake, src_file) + if os.path.exists(src): + return src + src = os.path.join(PYF_PATH, src_file) + if os.path.exists(src): + return src + src = os.path.join(PYF_PATH, "..", "ascendc", dir_snake, dir_snake + ".cpp") + if os.path.exists(src): + return src + src = os.path.join(PYF_PATH, "..", "ascendc", dir_ex, dir_ex + ".cpp") + if os.path.exists(src): + return src + src = os.path.join(PYF_PATH, "..", "ascendc", os.path.splitext(src_file)[0], src_file) + if os.path.exists(src): + return src + return src_ex + +''' + +IMPL_API = """ +@tbe_register.register_operator("{}", trans_bool_to_s8=False) +@para_check.check_op_params({}) +def {}({}, kernel_name="{}", impl_mode=""): +{} + if get_current_build_config("enable_op_prebuild"): + return + __inputs__, __outputs__, __attrs__ = _build_args({}) + options = get_dtype_fmt_options(__inputs__, __outputs__) + options += ["-x", "cce"] + bisheng = os.environ.get('BISHENG_REAL_PATH') + if bisheng is None: + bisheng = shutil.which("bisheng") + if bisheng != None: + bisheng_path = os.path.dirname(bisheng) + tikcpp_path = os.path.realpath(os.path.join(bisheng_path, "..", "..", "tikcpp")) + else: + tikcpp_path = os.path.realpath("/usr/local/Ascend/latest/compiler/tikcpp") + options.append("-I" + tikcpp_path) + options.append("-I" + os.path.join(tikcpp_path, "..", "..", "include")) + options.append("-I" + os.path.join(tikcpp_path, "tikcfw")) + options.append("-I" + os.path.join(tikcpp_path, "tikcfw", "impl")) + options.append("-I" + os.path.join(tikcpp_path, "tikcfw", "interface")) + options.append("-I" + os.path.join(PYF_PATH, "..", "ascendc", "common")) + if impl_mode == "high_performance": + options.append("-DHIGH_PERFORMANCE=1") + elif impl_mode == "high_precision": + options.append("-DHIGH_PRECISION=1") + if get_current_build_config("enable_deterministic_mode") == 1: + options.append("-DDETERMINISTIC_MODE=1") + else: + options.append("-DDETERMINISTIC_MODE=0") + + custom_compile_options = {}, + custom_all_compile_options = {}, + soc_version = get_soc_spec("SOC_VERSION") + soc_short = get_soc_spec("SHORT_SOC_VERSION").lower() + custom_compile_options_soc = get_shortsoc_compile_option(custom_compile_options[0], soc_short) + custom_all_compile_options_soc = get_shortsoc_compile_option(custom_all_compile_options[0], soc_short) + options += custom_all_compile_options_soc + options += custom_compile_options_soc + + origin_func_name = "{}" + ascendc_src_dir_ex = "{}" + ascendc_src_dir = "{}" + ascendc_src_file = "{}" + src = get_kernel_source(ascendc_src_file, ascendc_src_dir, ascendc_src_dir_ex) +""" + +REPLAY_OP_API = """ + msg = "start replay Ascend C Operator {}, kernel name is {}" + CommonUtility.print_compile_log("", msg, AscendCLogLevel.LOG_INFO) + tikreplay_codegen_path = tikcpp_path + "/tikreplaylib/lib" + tikreplay_stub_path = tikcpp_path + "/tikreplaylib/lib/" + soc_version + msg = "start load libtikreplaylib_codegen.so and libtikreplaylib_stub.so" + CommonUtility.print_compile_log("", msg, AscendCLogLevel.LOG_INFO) + codegen_so_path = tikreplay_codegen_path + "/libtikreplaylib_codegen.so" + replaystub_so_path = tikreplay_stub_path + "/libtikreplaylib_stub.so" + if PYF_PATH.endswith("dynamic"): + op_replay_path = os.path.join(PYF_PATH, "..", "..", "op_replay") + else: + op_replay_path = os.path.join(PYF_PATH, "..", "op_replay") + replayapi_so_path = 
os.path.join(op_replay_path, "libreplay_{}_" + soc_short + ".so") + load_dso(codegen_so_path) + load_dso(replaystub_so_path) + load_dso(replayapi_so_path) + op_type = "{}" + entry_obj = os.path.join(op_replay_path, "{}_entry_" + soc_short + ".o") + code_channel = get_code_channel(src, kernel_name, op_type, options) + op_info = OpInfo(kernel_name = kernel_name, op_type = op_type, inputs = __inputs__, outputs = __outputs__,\\ + attrs = __attrs__, impl_mode = impl_mode, param_type_dynamic = {}) + res, msg = replay_op(op_info, entry_obj, code_channel, src, options) + if not res: + print("call replay op failed for %s and get into call compile op" %(msg)) + compile_op(src, origin_func_name, op_info, options, code_channel, '{}') +""" + +COMPILE_OP_API = """ + msg = "start compile Ascend C Operator {}, kernel name is " + kernel_name + CommonUtility.print_compile_log("", msg, AscendCLogLevel.LOG_INFO) + op_type = "{}" + code_channel = get_code_channel(src, kernel_name, op_type, options) + op_info = OpInfo(kernel_name = kernel_name, op_type = op_type, inputs = __inputs__, outputs = __outputs__,\\ + attrs = __attrs__, impl_mode = impl_mode, origin_inputs=[{}], origin_outputs = [{}],\\ + param_type_dynamic = {}, mc2_ctx = {}, param_type_list = {}, init_value_list = {},\\ + output_shape_depend_on_compute = {}) + compile_op(src, origin_func_name, op_info, options, code_channel, '{}') +""" + +COMPILE_OP_API_BUILT_IN = """ + msg = "start compile Ascend C Operator {}, kernel name is " + kernel_name + CommonUtility.print_compile_log("", msg, AscendCLogLevel.LOG_INFO) + op_type = "{}" + code_channel = get_code_channel(src, kernel_name, op_type, options) + op_info = OpInfo(kernel_name = kernel_name, op_type = op_type, inputs = __inputs__, outputs = __outputs__,\\ + attrs = __attrs__, impl_mode = impl_mode, origin_inputs=[{}], origin_outputs = [{}],\\ + param_type_dynamic = {}, mc2_ctx = {}, param_type_list = {}, init_value_list = {},\\ + output_shape_depend_on_compute = {}) + + op_compile_option = '{}' + opp_path = os.environ.get('ASCEND_OPP_PATH') + dat_path = os.path.realpath(os.path.join(opp_path, "built-in", "op_impl", "ai_core", "tbe", "ascendc_impl.dat")) + if opp_path and os.path.exists(dat_path): + # dat file exists: built in hidden src file online compiling process. 
append vfs compile option in compile_op + abs_rel_kernel_src_path = "{}" + extend_options = {{}} + extend_options['opp_kernel_hidden_dat_path'] = dat_path + compile_op(abs_rel_kernel_src_path, origin_func_name, op_info, options, code_channel, op_compile_option,\\ + extend_options) + else: + raise RuntimeError("built-in opp compile, ascendc_impl.dat file path does not exist: %s" %(dat_path)) +""" + +SUP_API = """ +def {}({}, impl_mode=""): + __inputs__, __outputs__, __attrs__ = _build_args({}) + ret_str = check_op_cap("{}", "{}", __inputs__, __outputs__, __attrs__) + ret_dict = json.loads(ret_str) + err_code = ret_dict.get("ret_code") + sup = "Unknown" + reason = "Unknown reason" + if err_code is not None: + if err_code == 0: + sup = "True" + reason = "" + elif err_code == 1: + sup = "False" + reason = ret_dict.get("reason") + else: + sup = "Unknown" + reason = ret_dict.get("reason") + return sup, reason +""" +CAP_API = """ +def {}({}, impl_mode=""): + __inputs__, __outputs__, __attrs__ = _build_args({}) + result = check_op_cap("{}", "{}", __inputs__, __outputs__, __attrs__) + return result.decode("utf-8") +""" +GLZ_API = """ +@tbe_register.register_param_generalization("{}") +def {}_generalization({}, generalize_config=None): + __inputs__, __outputs__, __attrs__ = _build_args({}) + ret_str = generalize_op_params("{}", __inputs__, __outputs__, __attrs__, generalize_config) + return [json.loads(ret_str)] +""" + +ATTR_DEFAULT = { + "bool": "False", + "int": "0", + "float": "0.0", + "list_int": "[]", + "list_float": "[]", + "list_bool": "[]", + "list_list_int": "[[]]", + "str": "", +} + + +def optype_snake(origin_str): + temp_str = origin_str[0].lower() + origin_str[1:] + new_str = re.sub(r"([A-Z])", r"_\1", temp_str).lower() + return new_str + + +def optype_snake_ex(s): + snake_case = "" + for i, c in enumerate(s): + if i == 0: + snake_case += c.lower() + elif c.isupper(): + if s[i - 1] != "_": + if not s[i - 1].isupper(): + snake_case += "_" + elif s[i - 1].isupper() and (i + 1) < len(s) and s[i + 1].islower(): + snake_case += "_" + snake_case += c.lower() + else: + snake_case += c + return snake_case + + +class AdpBuilder(opdesc_parser.OpDesc): + def __init__(self: any, op_type: str): + self.argsdefv = [] + self.op_compile_option: str = "{}" + super().__init__(op_type) + + def write_adapt( + self: any, impl_path, path: str, op_compile_option_all: list = None + ): + self._build_paradefault() + if os.environ.get("BUILD_BUILTIN_OPP") != "1" and impl_path != "": + src_file = os.path.join(impl_path, self.op_file + ".cpp") + if not os.path.exists(src_file): + print( + f"[ERROR]: operator: {self.op_file} source file: {src_file} does not found, please check." 
+ ) + return + out_path = os.path.abspath(path) + if self.dynamic_shape and not out_path.endswith("dynamic"): + out_path = os.path.join(path, "dynamic") + os.makedirs(out_path, exist_ok=True) + adpfile = os.path.join(out_path, self.op_file + ".py") + self._gen_op_compile_option(op_compile_option_all) + with os.fdopen(os.open(adpfile, const_var.WFLAGS, const_var.WMODES), "w") as fd: + self._write_head(fd) + self._write_argparse(fd) + self._write_impl(fd, impl_path) + if self.op_chk_support: + self._write_cap("check_supported", fd) + self._write_cap("get_op_support_info", fd) + if self.op_fmt_sel: + self._write_cap("op_select_format", fd) + self._write_cap("get_op_specific_info", fd) + if self.op_range_limit == "limited" or self.op_range_limit == "dynamic": + self._write_glz(fd) + + def _gen_op_compile_option(self: any, op_compile_option_all: list = None): + if op_compile_option_all is not None: + if self.op_type in op_compile_option_all: + self.op_compile_option = op_compile_option_all[self.op_type] + elif "__all__" in op_compile_option_all: + self.op_compile_option = op_compile_option_all["__all__"] + + def _ip_argpack(self: any, default: bool = True) -> list: + args = [] + for i in range(len(self.input_name)): + arg = self.input_name[i] + if default and self.argsdefv[i] is not None: + arg += "=" + self.argsdefv[i] + args.append(arg) + return args + + def _op_argpack(self: any, default: bool = True) -> list: + args = [] + argidx = len(self.input_name) + for i in range(len(self.output_name)): + arg = self.output_name[i] + if default and self.argsdefv[i + argidx] is not None: + arg += "=" + self.argsdefv[i + argidx] + args.append(arg) + return args + + def _attr_argpack(self: any, default: bool = True) -> list: + args = [] + argidx = len(self.input_name) + len(self.output_name) + for i in range(len(self.attr_list)): + att = self.attr_list[i] + arg = att + if default and self.argsdefv[i + argidx] is not None: + if self.attr_val.get(att).get("type") == "str": + arg += '="' + self.argsdefv[i + argidx] + '"' + elif self.attr_val.get(att).get("type") == "bool": + arg += "=" + self.argsdefv[i + argidx].capitalize() + else: + arg += "=" + self.argsdefv[i + argidx] + args.append(arg) + return args + + def _build_paralist(self: any, default: bool = True) -> str: + args = [] + args.extend(self._ip_argpack(default)) + args.extend(self._op_argpack(default)) + args.extend(self._attr_argpack(default)) + return ", ".join(args) + + def _io_parachk(self: any, types: list, type_name: str) -> list: + chk = [] + for iot in types: + if iot == "optional": + ptype = "OPTION" + else: + ptype = iot.upper() + chk.append("para_check.{}_{}".format(ptype, type_name)) + return chk + + def _attr_parachk(self: any) -> list: + chk = [] + for att in self.attr_list: + att_type = self.attr_val.get(att).get("type").upper() + chk.append("para_check.{}_ATTR_{}".format("OPTION", att_type)) + return chk + + def _build_parachk(self: any) -> str: + chk = [] + chk.extend(self._io_parachk(self.input_type, "INPUT")) + chk.extend(self._io_parachk(self.output_type, "OUTPUT")) + chk.extend(self._attr_parachk()) + chk.append("para_check.KERNEL_NAME") + return ", ".join(chk) + + def _build_virtual(self: any) -> str: + virt_exp = [] + for index in range(len(self.input_name)): + if self.input_virt.get(index) is None: + continue + val = [] + val.append('"param_name":"{}"'.format(self.input_name[index])) + val.append('"index":{}'.format(index)) + val.append('"dtype":"{}"'.format(self.input_dtype[index].split(",")[0])) + 
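+ # Virtual inputs are emitted as optional placeholder tensor descriptions with
+ # shape [1]; the dtype above and the format/ori_format below reuse the first
+ # declared dtype/format combination of this parameter.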
val.append('"format":"{}"'.format(self.input_fmt[index].split(",")[0])) + val.append('"ori_format":"{}"'.format(self.input_fmt[index].split(",")[0])) + val.append('"paramType":"optional"') + val.append('"shape":[1]') + val.append('"ori_shape":[1]') + virt_exp.append( + " " + self.input_name[index] + " = {" + ",".join(val) + "}" + ) + if len(virt_exp) > 0: + return "\n".join(virt_exp) + else: + return " # do ascendc build step" + + def _build_mc2_ctx(self: any): + if len(self.mc2_ctx) != 0: + return '["' + '", "'.join(self.mc2_ctx) + '"]' + return "[]" + + def _build_paradefault(self: any): + optional = False + argtypes = [] + argtypes.extend(self.input_type) + argtypes.extend(self.output_type) + in_idx = 0 + for atype in argtypes: + if atype == "optional": + optional = True + if optional: + self.argsdefv.append("None") + else: + self.argsdefv.append(None) + in_idx += 1 + for attr in self.attr_list: + atype = self.attr_val.get(attr).get("paramType") + if atype == "optional": + optional = True + attrval = self.attr_val.get(attr).get("defaultValue") + if attrval is not None: + optional = True + if type == "bool": + attrval = attrval.capitalize() + elif type == "str": + attrval = '"' + attrval + '"' + self.argsdefv.append(attrval) + continue + if optional: + self.argsdefv.append( + ATTR_DEFAULT.get(self.attr_val.get(attr).get("type")) + ) + else: + self.argsdefv.append(None) + + def _write_head(self: any, fd: object): + now = datetime.datetime.now() + curr_year = now.year + former_year = curr_year - 1 + fd.write( + IMPL_HEAD.format( + former_year, curr_year, self.input_ori_name, self.output_ori_name + ) + ) + + def _write_argparse(self: any, fd: object): + args = self._build_paralist(False) + fd.write("def _build_args({}):\n".format(args)) + fd.write(" __inputs__ = []\n") + fd.write(" for arg in [{}]:\n".format(", ".join(self.input_name))) + fd.write(" if arg != None:\n") + fd.write(" if isinstance(arg, (list, tuple)):\n") + fd.write(" if len(arg) == 0:\n") + fd.write(" continue\n") + fd.write(" __inputs__.append(arg[0])\n") + fd.write(" else:\n") + fd.write(" __inputs__.append(arg)\n") + fd.write(" else:\n") + fd.write(" __inputs__.append(arg)\n") + fd.write(" __outputs__ = []\n") + fd.write(" for arg in [{}]:\n".format(", ".join(self.output_name))) + fd.write(" if arg != None:\n") + fd.write(" if isinstance(arg, (list, tuple)):\n") + fd.write(" if len(arg) == 0:\n") + fd.write(" continue\n") + fd.write(" __outputs__.append(arg[0])\n") + fd.write(" else:\n") + fd.write(" __outputs__.append(arg)\n") + fd.write(" else:\n") + fd.write(" __outputs__.append(arg)\n") + fd.write(" __attrs__ = []\n") + for attr in self.attr_list: + fd.write(" if {} != None:\n".format(attr)) + fd.write(" attr = {}\n") + fd.write(' attr["name"] = "{}"\n'.format(attr)) + fd.write( + ' attr["dtype"] = "{}"\n'.format( + self.attr_val.get(attr).get("type") + ) + ) + fd.write(' attr["value"] = {}\n'.format(attr)) + fd.write(" __attrs__.append(attr)\n") + fd.write(" return __inputs__, __outputs__, __attrs__\n") + + def _get_kernel_source(self: any, kernel_src_dir, src_file, dir_snake, dir_ex): + src_ex = os.path.join(kernel_src_dir, dir_ex, src_file) + if os.path.exists(src_ex): + return src_ex + src = os.environ.get("BUILD_KERNEL_SRC") + if src and os.path.exists(src): + return src + src = os.path.join(kernel_src_dir, dir_snake, src_file) + if os.path.exists(src): + return src + src = os.path.join(kernel_src_dir, src_file) + if os.path.exists(src): + return src + src = os.path.join(kernel_src_dir, dir_snake, dir_snake + 
".cpp") + if os.path.exists(src): + return src + src = os.path.join(kernel_src_dir, dir_ex, dir_ex + ".cpp") + if os.path.exists(src): + return src + src = os.path.join(kernel_src_dir, os.path.splitext(src_file)[0], src_file) + if os.path.exists(src): + return src + return src_ex + + def _write_impl(self: any, fd: object, impl_path: str = ""): + argsdef = self._build_paralist() + argsval = self._build_paralist(False) + pchk = self._build_parachk() + if len(self.kern_name) > 0: + kern_name = self.kern_name + else: + kern_name = self.op_intf + src = self.op_file + ".cpp" + virt_exprs = self._build_virtual() + fd.write( + IMPL_API.format( + self.op_type, + pchk, + self.op_intf, + argsdef, + kern_name, + virt_exprs, + argsval, + self.custom_compile_options, + self.custom_all_compile_options, + self.op_intf, + optype_snake_ex(self.op_type), + optype_snake(self.op_type), + src, + ) + ) + if self.op_replay_flag: + fd.write( + REPLAY_OP_API.format( + self.op_type, + kern_name, + self.op_file, + self.op_type, + self.op_file, + self.param_type_dynamic, + self.op_compile_option, + ) + ) + else: + if os.environ.get("BUILD_BUILTIN_OPP") == "1": + relative_kernel_src_path = os.path.realpath( + self._get_kernel_source( + impl_path, + src, + optype_snake(self.op_type), + optype_snake_ex(self.op_type), + ) + ) + # to match src path in .dat file system, turn relative path into absolute path + abs_rel_kernel_src_path = os.path.join( + "/", os.path.relpath(relative_kernel_src_path, impl_path) + ) + + # compiling hidden src file requires src path before packaging .dat file, + # hard code such src path to .py + fd.write( + COMPILE_OP_API_BUILT_IN.format( + self.op_type, + self.op_type, + ", ".join(self.input_name), + ", ".join(self.output_name), + self.param_type_dynamic, + self._build_mc2_ctx(), + self.input_type + self.output_type, + self.output_init_value, + self.output_shape_depend_on_compute, + self.op_compile_option, + abs_rel_kernel_src_path, + ) + ) + else: + fd.write( + COMPILE_OP_API.format( + self.op_type, + self.op_type, + ", ".join(self.input_name), + ", ".join(self.output_name), + self.param_type_dynamic, + self._build_mc2_ctx(), + self.input_type + self.output_type, + self.output_init_value, + self.output_shape_depend_on_compute, + self.op_compile_option, + ) + ) + + def _write_cap(self: any, cap_name: str, fd: object): + argsdef = self._build_paralist() + argsval = self._build_paralist(False) + if cap_name == "check_supported": + fd.write(SUP_API.format(cap_name, argsdef, argsval, cap_name, self.op_type)) + else: + fd.write(CAP_API.format(cap_name, argsdef, argsval, cap_name, self.op_type)) + + def _write_glz(self: any, fd: object): + argsdef = self._build_paralist() + argsval = self._build_paralist(False) + fd.write( + GLZ_API.format(self.op_type, self.op_intf, argsdef, argsval, self.op_type) + ) + + +def write_scripts( + cfgfile: str, + cfgs: dict, + dirs: dict, + ops: list = None, + op_compile_option: list = None, +): + batch_lists = cfgs.get(const_var.REPLAY_BATCH).split(";") + iterator_lists = cfgs.get(const_var.REPLAY_ITERATE).split(";") + file_map = {} + op_descs = opdesc_parser.get_op_desc( + cfgfile, + batch_lists, + iterator_lists, + AdpBuilder, + ops, + dirs.get(const_var.AUTO_GEN_DIR), + ) + for op_desc in op_descs: + op_desc.write_adapt( + dirs.get(const_var.CFG_IMPL_DIR), + dirs.get(const_var.CFG_OUT_DIR), + op_compile_option, + ) + file_map[op_desc.op_type] = op_desc.op_file + return file_map + + +class OpFileNotExistsError(Exception): + """File does not exist error.""" + + def 
__str__(self) -> str: + return ( + f"File aic-*-ops-info.ini does not exist in directory {super().__str__()}" + ) + + +def get_ops_info_files(opsinfo_dir: List[str]) -> List[str]: + """Get all ops info files.""" + ops_info_files = [] + for _dir in opsinfo_dir: + ops_info_files.extend(glob.glob(f"{_dir}/aic-*-ops-info.ini")) + return sorted(ops_info_files) + + +def parse_args(argv): + """Command line parameter parsing""" + parser = argparse.ArgumentParser() + parser.add_argument("argv", nargs="+") + parser.add_argument("--opsinfo-dir", nargs="*", default=None) + return parser.parse_args(argv) + + +if __name__ == "__main__": + args = parse_args(sys.argv) + + if len(args.argv) <= 6: + raise RuntimeError("arguments must greater equal than 6") + + rep_cfg = {} + rep_cfg[const_var.REPLAY_BATCH] = args.argv[2] + rep_cfg[const_var.REPLAY_ITERATE] = args.argv[3] + + cfg_dir = {} + cfg_dir[const_var.CFG_IMPL_DIR] = args.argv[4] + cfg_dir[const_var.CFG_OUT_DIR] = args.argv[5] + cfg_dir[const_var.AUTO_GEN_DIR] = args.argv[6] + + ops_infos = [] + if args.opsinfo_dir: + ops_infos.extend(get_ops_info_files(args.opsinfo_dir)) + if not ops_infos: + raise OpFileNotExistsError(args.opsinfo_dir) + else: + ops_infos.append(args.argv[1]) + + for ops_info in ops_infos: + write_scripts(cfgfile=ops_info, cfgs=rep_cfg, dirs=cfg_dir) diff --git a/csrc/deepep/ops2/cmake/util/ascendc_op_info.py b/csrc/deepep/ops2/cmake/util/ascendc_op_info.py new file mode 100755 index 00000000..a7540428 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/ascendc_op_info.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import os +import sys + +import opdesc_parser + +PYF_PATH = os.path.dirname(os.path.realpath(__file__)) + + +class OpInfo: + def __init__(self: any, op_type: str, cfg_file: str): + op_descs = opdesc_parser.get_op_desc( + cfg_file, [], [], opdesc_parser.OpDesc, [op_type] + ) + if op_descs is None or len(op_descs) != 1: + raise RuntimeError("cannot get op info of {}".format(op_type)) + self.op_desc = op_descs[0] + + def get_op_file(self: any): + return self.op_desc.op_file + + def get_op_intf(self: any): + return self.op_desc.op_intf + + def get_inputs_name(self: any): + return self.op_desc.input_ori_name + + def get_outputs_name(self: any): + return self.op_desc.output_ori_name + + +if __name__ == "__main__": + if len(sys.argv) <= 2: + raise RuntimeError("arguments must greater than 2") + op_info = OpInfo(sys.argv[1], sys.argv[2]) + print(op_info.get_op_file()) + print(op_info.get_op_intf()) + print(op_info.get_inputs_name()) + print(op_info.get_outputs_name()) diff --git a/csrc/deepep/ops2/cmake/util/ascendc_ops_config.py b/csrc/deepep/ops2/cmake/util/ascendc_ops_config.py new file mode 100755 index 00000000..8c5dd276 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/ascendc_ops_config.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import argparse +import glob +import json +import os +import sys + +import const_var + +BINARY_INFO_CONFIG_JSON = "binary_info_config.json" + + +def load_json(json_file: str): + with open(json_file, encoding="utf-8") as file: + json_content = json.load(file) + return json_content + + +def get_specified_suffix_file(root_dir, suffix): + specified_suffix = os.path.join(root_dir, "**/*.{}".format(suffix)) + all_suffix_files = glob.glob(specified_suffix, recursive=True) + return sorted(all_suffix_files) + + +def add_dict_key(dict_to_add, key, value): + if value is None: + return + dict_to_add[key] = value + + +def correct_format_mode(format_mode): + if 
format_mode == "FormatDefault": + return "nd_agnostic" + if format_mode == "FormatAgnostic": + return "static_nd_agnostic" + if format_mode == "FormatFixed": + return "normal" + return format_mode + + +def get_input_or_output_config(in_or_out): + param_dict = {} + name = in_or_out.get("name") + index = in_or_out.get("index") + param_type = in_or_out.get("paramType") + + format_match_mode = in_or_out.get("format_match_mode") + format_mode = correct_format_mode(format_match_mode) + + dtype_mode = in_or_out.get("dtype_match_mode") + if dtype_mode == "DtypeByte": + dtype_mode = "bit" + + add_dict_key(param_dict, "name", name) + add_dict_key(param_dict, "index", index) + add_dict_key(param_dict, "paramType", param_type) + add_dict_key(param_dict, "dtypeMode", dtype_mode) + add_dict_key(param_dict, "formatMode", format_mode) + return param_dict + + +def get_inputs_or_outputs_config(inputs_or_outputs): + if inputs_or_outputs is None: + return None + inputs_or_outputs_list = [] + + for in_or_out in inputs_or_outputs: + if isinstance(in_or_out, dict): + dict_param_config = get_input_or_output_config(in_or_out) + inputs_or_outputs_list.append(dict_param_config) + elif isinstance(in_or_out, list): + param_info = in_or_out[0] + list_param_config = get_input_or_output_config(param_info) + tmp_list = [list_param_config] + inputs_or_outputs_list.append(tmp_list) + return inputs_or_outputs_list + + +def gen_attrs_config(attrs): + attrs_list = [] + for attr in attrs: + attrs_dict = {} + name = attr.get("name") + mode = attr.get("mode") + add_dict_key(attrs_dict, "name", name) + add_dict_key(attrs_dict, "mode", mode) + attrs_list.append(attrs_dict) + return attrs_list + + +def get_params_config(support_info): + params_dict = {} + + inputs = support_info.get("inputs") + inputs_list = get_inputs_or_outputs_config(inputs) + params_dict["inputs"] = inputs_list + + outputs = support_info.get("outputs") + outputs_list = get_inputs_or_outputs_config(outputs) + params_dict["outputs"] = outputs_list + + attrs = support_info.get("attrs") + if attrs is not None: + attrs_list = gen_attrs_config(attrs) + params_dict["attrs"] = attrs_list + + return params_dict + + +def add_simplified_config( + op_type, support_info, core_type, task_ration, objfile, config +): + simplified_key = support_info.get("simplifiedKey") + + json_path = objfile.split(".")[0] + ".json" + + simple_cfg = config.get(BINARY_INFO_CONFIG_JSON) + op_cfg = simple_cfg.get(op_type) + if not op_cfg: + op_cfg = {"dynamicRankSupport": True} + + simplified_key_mode = support_info.get("simplifiedKeyMode") + add_dict_key(op_cfg, "simplifiedKeyMode", simplified_key_mode) + + optional_input_mode = support_info.get("optionalInputMode") + optional_output_mode = support_info.get("optionalOutputMode") + add_dict_key(op_cfg, "optionalInputMode", optional_input_mode) + if optional_output_mode is not None: + add_dict_key(op_cfg, "optionalOutputMode", optional_output_mode) + + params_info = get_params_config(support_info) + op_cfg["params"] = params_info + op_cfg["binaryList"] = [] + simple_cfg[op_type] = op_cfg + + bin_list = op_cfg.get("binaryList") + if core_type == 0 and task_ration == "tilingKey": + bin_list.append( + { + "coreType": core_type, + "simplifiedKey": simplified_key, + "multiKernelType": 1, + "binPath": objfile, + "jsonPath": json_path, + } + ) + else: + bin_list.append( + { + "coreType": core_type, + "simplifiedKey": simplified_key, + "binPath": objfile, + "jsonPath": json_path, + } + ) + + +def add_op_config(op_file, bin_info, config): + op_cfg = 
config.get(op_file) + if not op_cfg: + op_cfg = {"binList": []} + config[op_file] = op_cfg + op_cfg.get("binList").append(bin_info) + + +def gen_ops_config(json_file, soc, config): + core_type_map = { + "MIX": 0, + "AiCore": 1, + "VectorCore": 2, + "MIX_AICORE": 3, + "MIX_VECTOR_CORE": 4, + "MIX_AIV": 4, + } + contents = load_json(json_file) + if ("binFileName" not in contents) or ("supportInfo" not in contents): + return + json_base_name = os.path.basename(json_file) + op_dir = os.path.basename(os.path.dirname(json_file)) + + support_info = contents.get("supportInfo") + bin_name = contents.get("binFileName") + bin_suffix = contents.get("binFileSuffix") + core_type = contents.get("coreType") + task_ration = contents.get("taskRation") + core_type = core_type_map.get(core_type, -1) + if core_type == -1 and soc != "ascend310b": + raise Exception("[ERROR]: must set coreType in json when soc version is {soc}.") + + bin_file_name = bin_name + bin_suffix + op_type = bin_name.split("_")[0] + op_file = op_dir + ".json" + bin_info = {} + + add_dict_key(bin_info, "implMode", support_info.get("implMode")) + add_dict_key(bin_info, "int64Mode", support_info.get("int64Mode")) + add_dict_key(bin_info, "simplifiedKeyMode", support_info.get("simplifiedKeyMode")) + + simplified_key = support_info.get("simplifiedKey") + if simplified_key is not None: + bin_info["simplifiedKey"] = simplified_key + obj_file = os.path.join(soc, op_dir, bin_file_name) + add_simplified_config( + op_type, support_info, core_type, task_ration, obj_file, config + ) + + add_dict_key(bin_info, "dynamicParamMode", support_info.get("dynamicParamMode")) + bin_info["staticKey"] = support_info.get("staticKey") + bin_info["inputs"] = support_info.get("inputs") + bin_info["outputs"] = support_info.get("outputs") + if support_info.get("attrs"): + bin_info["attrs"] = support_info.get("attrs") + + add_dict_key(bin_info, "opMode", support_info.get("opMode")) + add_dict_key(bin_info, "optionalInputMode", support_info.get("optionalInputMode")) + add_dict_key(bin_info, "deterministic", support_info.get("deterministic")) + if support_info.get("optionalOutputMode") is not None: + add_dict_key( + bin_info, "optionalOutputMode", support_info.get("optionalOutputMode") + ) + + bin_info["binInfo"] = {"jsonFilePath": os.path.join(soc, op_dir, json_base_name)} + add_op_config(op_file, bin_info, config) + + +def check_single_op_is_void(root_dir): + for root, dirs, _ in os.walk(root_dir): + for sub_dir in dirs: + dir_path = os.path.join(root, sub_dir) + if len(os.listdir(dir_path)) == 0: + print(f"[ERROR] op {sub_dir}: not any obj compile success") + sys.exit(1) + + +def gen_all_config(root_dir, soc, out_dir, skip_binary_info_config): + suffix = "json" + config = {BINARY_INFO_CONFIG_JSON: {}} + check_single_op_is_void(root_dir) + all_json_files = get_specified_suffix_file(root_dir, suffix) + + for _json in all_json_files: + gen_ops_config(_json, soc, config) + file_path = soc + _json.split(soc)[1] + with open(_json, "r+") as f: + data = json.load(f) + data["filePath"] = file_path + f.seek(0) + json.dump(data, f, indent=" ") + f.truncate() + + for cfg_key in config.keys(): + if skip_binary_info_config and cfg_key == BINARY_INFO_CONFIG_JSON: + continue + cfg_file = os.path.join(out_dir, cfg_key) + with os.fdopen( + os.open(cfg_file, const_var.WFLAGS, const_var.WMODES), "w" + ) as fd: + json.dump(config.get(cfg_key), fd, indent=" ") + + +# Parse multiple soc_versions ops in single path. 
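+# Each <all_path>/ascend*/ sub-directory is treated as one soc_version: its op
+# configs are generated in place via gen_all_config(), and the resulting
+# top-level *.json files are then collected under <all_path>/config/<soc>.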
+def gen_all_soc_config(all_path): + soc_roots = glob.glob(os.path.join(all_path, "ascend*")) + + for soc_root in soc_roots: + soc = os.path.basename(soc_root) + gen_all_config(soc_root, soc, soc_root, True) + cfg_files = glob.glob(os.path.join(soc_root, "*.json")) + cfg_path = os.path.join(all_path, "config", soc) + os.makedirs(cfg_path, exist_ok=True) + for cfg_file in cfg_files: + new_file = os.path.join(cfg_path, os.path.basename(cfg_file)) + os.rename(cfg_file, new_file) + + +def args_prase(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-p", + "--path", + nargs="?", + required=True, + help="Parse the path of the json file.", + ) + + parser.add_argument( + "-s", "--soc", nargs="?", required=True, help="Parse the soc_version of ops." + ) + + parser.add_argument("-o", "--out", nargs="?", help="Output directory.") + + parser.add_argument( + "--skip-binary-info-config", + action="store_true", + help="binary_info_config.json file is not parsed.", + ) + + return parser.parse_args() + + +def main(): + args = args_prase() + if args.out is None: + out_dir = args.path + else: + out_dir = args.out + + gen_all_config(args.path, args.soc, out_dir, args.skip_binary_info_config) + + +if __name__ == "__main__": + main() diff --git a/csrc/deepep/ops2/cmake/util/ascendc_pack_kernel.py b/csrc/deepep/ops2/cmake/util/ascendc_pack_kernel.py new file mode 100755 index 00000000..430c1ba8 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/ascendc_pack_kernel.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import argparse +import glob +import json +import math +import os +import subprocess +import sys + +import ascendc_ops_config +import const_var +from tbe.tikcpp.log_utils import AscendCLogLevel, LogUtil + + +class PackKernel: + def __init__(self: any, args: any): + self.in_path = os.path.realpath(args.input_path) + self.out_path = os.path.realpath(args.output_path) + self.is_lib = args.enable_library + self.platform = args.platform + self.op_info = {} + self.file_info = {} + try: + os.makedirs(self.out_path, exist_ok=True) + except Exception as e: + LogUtil.print_compile_log( + "", + f"make {self.out_path} error: {e}!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + + def load_json(self: any, json_file: str): + with open(json_file, encoding="utf-8") as file: + json_content = json.load(file) + return json_content + + def get_symbol(self: any, name: str): + name = name.replace("/", "_") + return name.replace(".", "_") + + def ascendc_gen_object(self: any, in_file: str, soc: str): + sym = self.get_symbol("_binary_" + in_file) + out_file = os.path.join(self.out_path, sym + ".o") + # ascend610lite only support aarch64 + if soc == "ascend610lite": + try: + subprocess.run( + [ + "llvm-objcopy", + "--input-target", + "binary", + "--output-target", + "elf64-littleaarch64", + "--binary-architecture", + "aarch64", + in_file, + out_file, + ] + ) + except Exception as e: + LogUtil.print_compile_log( + "", + " ascend610lite execute objcopy fail!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + return None + return [sym + "_start", sym + "_end"] + uname = os.popen("uname -m").read().strip() + if self.platform is not None: + target_platform = self.platform + else: + target_platform = uname + try: + if target_platform == "x86_64": + subprocess.run( + [ + "llvm-objcopy", + "--input-target", + "binary", + "--output-target", + "elf64-x86-64", + "--binary-architecture", + "i386", + in_file, + out_file, + ] + ) + elif target_platform == "aarch64": + 
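+ # Native or cross aarch64 target: llvm-objcopy wraps the raw file into a
+ # little-endian AArch64 ELF object so the _binary_<name>_start/_end symbols
+ # can be referenced from the generated header.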
subprocess.run( + [ + "llvm-objcopy", + "--input-target", + "binary", + "--output-target", + "elf64-littleaarch64", + "--binary-architecture", + "aarch64", + in_file, + out_file, + ] + ) + else: + subprocess.run(["echo", "unsported environment!"]) + except Exception as e: + LogUtil.print_compile_log( + "", + f"{target_platform} execute objcopy error: {e}!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + return None + return [sym + "_start", sym + "_end"] + + def ascendc_get_config(self: any): + os.chdir(self.in_path) + soc_vers = os.listdir("config") + for soc in soc_vers: + bin_infos = glob.glob(os.path.join("config", soc, "*.json")) + cfgs = {} + for bin_info in bin_infos: + if bin_info.find("binary_info_config.json") > 0: + continue + jobj = self.load_json(bin_info) + for bin_cfg in jobj.get("binList"): + js_cfg = bin_cfg.get("binInfo").get("jsonFilePath") + op_type = os.path.basename(js_cfg).split("_")[0] + if cfgs.get(op_type) is None: + op_obj = {} + op_obj["obj"] = [] + op_obj["cfg"] = bin_info + cfgs[op_type] = op_obj + op_obj = cfgs.get(op_type) + op_obj.get("obj").append(js_cfg[:-5]) + self.file_info[soc] = cfgs + + def ascendc_pack_kernel(self: any): + for soc in self.file_info.keys(): + os.chdir(self.in_path) + op_cfgs = self.file_info.get(soc) + for op_type in op_cfgs.keys(): + op_obj = op_cfgs.get(op_type) + if self.op_info.get(op_type) is None: + op_info = {} + op_info["op_fun"] = ["nullptr", "nullptr"] + op_info["op_bin"] = {} + op_info["op_rkb"] = [] + self.op_info[op_type] = op_info + op_info = self.op_info.get(op_type) + op_bin = op_info.get("op_bin") + if op_bin.get(soc) is None: + op_bin[soc] = [] + op_bin[soc].append(self.ascendc_gen_object(op_obj["cfg"], soc)) + op_soc = op_bin.get(soc) + for objs in op_obj["obj"]: + op_soc.append(self.ascendc_gen_object(objs + ".json", soc)) + op_soc.append(self.ascendc_gen_object(objs + ".o", soc)) + + def ascendc_gen_header(self: any): + for op_type in self.op_info.keys(): + op_obj = self.op_info.get(op_type) + macro_op = ( + "#define {}_OP_RESOURCES std::make_tuple, \\\n" + " std::map>>, \\\n" + " std::vector>>({{{}}}, \\\n".format( + op_type, ", ".join(op_obj.get("op_fun")) + ) + ) + op_bin = op_obj.get("op_bin") + socs_res = [] + op_syms = [] + for soc in op_bin.keys(): + soc_res = '{{ "{}", {{'.format(soc) + soc_syms = op_bin.get(soc) + soc_pairs = [] + for pair_addr in soc_syms: + pair_addr1 = ["&" + s for s in pair_addr] + op_syms += pair_addr + soc_pairs.append( + " {{ {} }} ".format(", \\\n ".join(pair_addr1)) + ) + soc_res += ", \\\n ".join(soc_pairs) + soc_res += " } }" + socs_res.append(soc_res) + macro_op += " {{ {} }}, \\\n".format(", \\\n ".join(socs_res)) + macro_op += " {{ {} }})\n\n".format(", ".join(op_obj.get("op_rkb"))) + macro_str = '#define {}_RESOURCES {{{{"{}", {}}}}}'.format( + op_type, op_type, "{}_OP_RESOURCES".format(op_type) + ) + var_str = ( + "extern gert::OpImplRegisterV2 op_impl_register_optiling_{};\n".format( + op_type + ) + ) + if len(op_syms) > 0: + var_str += ( + "extern uint8_t " + ";\nextern uint8_t ".join(op_syms) + ";\n" + ) + head_file = os.path.join(self.out_path, "{}_op_resource.h".format(op_type)) + try: + with os.fdopen( + os.open(head_file, const_var.WFLAGS, const_var.WMODES), "w" + ) as fd: + fd.write("#include \n") + fd.write("#include \n") + fd.write("#include \n") + fd.write("#include \n") + fd.write('#include "graph/ascend_string.h"\n') + fd.write('#include "register/op_impl_registry.h"\n\n') + fd.write(var_str) + fd.write("\n") + fd.write(macro_op) + 
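+ # macro_op above bundles the embedded per-soc kernel objects into
+ # {OP}_OP_RESOURCES; macro_str, written next, defines {OP}_RESOURCES as a
+ # one-entry map from the op type string to that bundle.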
fd.write(macro_str) + except Exception as e: + LogUtil.print_compile_log( + "", + f"{op_type}_op_resource.h create error: {e}!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + + def ascendc_gen_lib(self: any): + out_lib = os.path.join(self.out_path, "libkernels.a") + if os.path.exists(out_lib): + os.remove(out_lib) + objs = glob.glob(os.path.join(self.out_path, "*.o")) + start = 0 + batch_size = 100 + for _ in range(math.ceil(len(objs) / batch_size)): + sub_objs = objs[start : start + batch_size] + start += batch_size + try: + subprocess.run(["ar", "qc", out_lib] + sub_objs) + subprocess.run(["ranlib", out_lib]) + except Exception as e: + LogUtil.print_compile_log( + "", + f"execute ar/ranlib command error: {e}!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + + def ascendc_gen_opsinfo(self: any): + ascendc_ops_config.gen_all_soc_config(self.in_path) + + +def args_parse(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", "--input-path", nargs="?", help="Input path of compile result." + ) + parser.add_argument( + "-o", "--output-path", nargs="?", help="Output path of compile result." + ) + parser.add_argument( + "-l", + "--enable-library", + nargs="?", + default=None, + help="Whether library is enabled.", + ) + parser.add_argument( + "-p", + "--platform", + nargs="?", + default=None, + help="target platform is x86_64 or aarch64.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = args_parse() + kernel_packer = PackKernel(args) + if kernel_packer.is_lib is None: + kernel_packer.ascendc_gen_opsinfo() + kernel_packer.ascendc_get_config() + kernel_packer.ascendc_pack_kernel() + kernel_packer.ascendc_gen_header() + kernel_packer.ascendc_gen_lib() diff --git a/csrc/deepep/ops2/cmake/util/ascendc_pack_opregistry.py b/csrc/deepep/ops2/cmake/util/ascendc_pack_opregistry.py new file mode 100755 index 00000000..1acb7e45 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/ascendc_pack_opregistry.py @@ -0,0 +1,365 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import argparse +import glob +import math +import os +import shutil +import subprocess +import sys + +import const_var +from tbe.tikcpp.log_utils import AscendCLogLevel, LogUtil + + +class PackKernel: + def __init__(self: any, args: any): + self.in_path = os.path.realpath(args.input_path) + self.copy_path = os.path.realpath(args.copy_path) + self.out_path = os.path.realpath(args.output_path) + self.op_soc_ver = args.compute_unit.split("-") + self.vendor_name = args.vendor_name + self.framework_type = args.framework_type + self.platform = args.platform + self.op_info = {} + self.file_info = {} + if os.path.exists(self.copy_path): + try: + shutil.rmtree(self.copy_path) + except OSError as e: + LogUtil.print_compile_log( + "", + f"remove {self.copy_path} error!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + if os.path.exists(self.out_path): + try: + shutil.rmtree(self.out_path) + except OSError as e: + LogUtil.print_compile_log( + "", + f"remove {self.out_path} error!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + try: + os.makedirs(self.copy_path, exist_ok=True) + except Exception as e: + LogUtil.print_compile_log( + "", + f"make {self.copy_path} error: {e}!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + try: + os.makedirs(self.out_path, exist_ok=True) + except Exception as e: + LogUtil.print_compile_log( + "", + f"make {self.out_path} error: {e}!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + + def 
get_symbol(self: any, name: str): + name = name.replace("/", "_") + name = name.replace("-", "_") + return name.replace(".", "_") + + def ascendc_gen_object(self: any, in_file: str, path: str): + sym = self.get_symbol("_binary_" + in_file) + out_file = os.path.join(self.out_path, sym + ".o") + # ascend610lite only support aarch64 + if path.find("ascend610lite") != -1: + try: + subprocess.run( + [ + "llvm-objcopy", + "--input-target", + "binary", + "--output-target", + "elf64-littleaarch64", + "--binary-architecture", + "aarch64", + in_file, + out_file, + ] + ) + except Exception as e: + LogUtil.print_compile_log( + "", + " ascend610lite execute objcopy fail!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + return None + return [sym + "_start", sym + "_end"] + + uname = os.popen("uname -m").read().strip() + if self.platform is not None: + target_platform = self.platform + else: + target_platform = uname + try: + if target_platform == "x86_64": + subprocess.run( + [ + "llvm-objcopy", + "--input-target", + "binary", + "--output-target", + "elf64-x86-64", + "--binary-architecture", + "i386", + in_file, + out_file, + ] + ) + elif target_platform == "aarch64": + subprocess.run( + [ + "llvm-objcopy", + "--input-target", + "binary", + "--output-target", + "elf64-littleaarch64", + "--binary-architecture", + "aarch64", + in_file, + out_file, + ] + ) + else: + subprocess.run(["echo", "unsupported environment!"]) + except Exception as e: + LogUtil.print_compile_log( + "", + f"{target_platform} execute objcopy error: {e}!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + return None + return [sym + "_start", sym + "_end"] + + def ascendc_get_config(self: any): + os.chdir(self.copy_path) + current_directory = os.getcwd() + catalog_file = os.listdir(current_directory) + for catalog in catalog_file: + if catalog == "op_impl" or catalog == "framework": + files_dict = {} + for root, _, files in os.walk(catalog): + for file in files: + if ( + file.endswith(".json") + or file.endswith(".so") + or file.endswith(".cpp") + or file.endswith(".py") + or file.endswith(".o") + ): + file_path = os.path.join(root, file) + file_name = os.path.basename(file_path) + files_dict[file_name] = file_path + self.file_info[catalog] = files_dict + + def ascendc_pack_kernel(self: any): + op_info = {} + for files in self.file_info.keys(): + os.chdir(self.copy_path) + op_cfgs = self.file_info.get(files) + for file_name in op_cfgs.keys(): + op_info[file_name] = [] + path, filename = os.path.split(op_cfgs[file_name]) + op_info[file_name].append(os.path.join(self.vendor_name, path)) + op_info[file_name].append( + self.ascendc_gen_object(op_cfgs[file_name], path) + ) + self.op_info = op_info + + def ascendc_gen_header(self: any): + socs_res = [] + var_str = "" + macro_op = ( + "std::vector> __ascendc_op_info = \n" + ) + for file_name in self.op_info.keys(): + file_addr = self.op_info.get(file_name) + soc_pairs = [] + op_syms = [] + soc_res = ' {{ "{}", '.format(file_name) + soc_res += '"{}", '.format(file_addr[0]) + for pair_addr in file_addr[1]: + op_syms.append(pair_addr) + pair_addr1 = "&" + pair_addr + soc_pairs.append(pair_addr1) + soc_res += "{}, {}".format(soc_pairs[0], soc_pairs[1]) + soc_res += "}, \n" + socs_res.append(soc_res) + if len(op_syms) > 0: + var_str += "".join( + ["extern uint8_t {};\n".format(sym) for sym in op_syms] + ) + macro_op += "{{\n{}}}; \n".format("".join(socs_res)) + head_file = os.path.join(self.out_path, "ge_table_op_resource.h") + try: + with os.fdopen( + os.open(head_file, 
const_var.WFLAGS, const_var.WMODES), "w" + ) as fd: + fd.write("#include \n") + fd.write("#include \n") + fd.write("#include \n") + fd.write("#include \n") + fd.write('#include "graph/ascend_string.h"\n') + fd.write('#include "register/op_impl_registry.h"\n\n') + fd.write(var_str) + fd.write("\n") + fd.write("namespace AscendC {\n") + fd.write(macro_op) + fd.write("}\n") + except Exception as e: + LogUtil.print_compile_log( + "", + f"ge_table_op_resource.h create error: {e}!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + + def ascendc_gen_lib(self: any): + out_lib = os.path.join(self.out_path, "libopregistry.a") + if os.path.exists(out_lib): + os.remove(out_lib) + objs = glob.glob(os.path.join(self.out_path, "*.o")) + start = 0 + batch_size = 100 + for _ in range(math.ceil(len(objs) / batch_size)): + sub_objs = objs[start : start + batch_size] + start += batch_size + try: + subprocess.run(["ar", "qc", out_lib] + sub_objs) + subprocess.run(["ranlib", out_lib]) + except Exception as e: + LogUtil.print_compile_log( + "", + f"execute ar/ranlib command error: {e}!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + + def ascendc_copy_dir(self: any, src_dir: str, target_dir: str): + file_list = os.listdir(src_dir) + for file_name in file_list: + source_file = os.path.join(src_dir, file_name) + target_file = os.path.join(target_dir, file_name) + if os.path.isdir(source_file): + try: + shutil.copytree(source_file, target_file) + except Exception as e: + LogUtil.print_compile_log( + "", + f"copy {source_file} error: {e}!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + + def ascendc_copy_file(self: any, src_dir: str, target_dir: str): + file_list = os.listdir(src_dir) + for file_name in file_list: + source_file = os.path.join(src_dir, file_name) + if os.path.isfile(source_file): + try: + os.makedirs(target_dir, exist_ok=True) + except Exception as e: + LogUtil.print_compile_log( + "", + f"make {target_dir} error: {e}!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + try: + shutil.copy(source_file, target_dir) + except Exception as e: + LogUtil.print_compile_log( + "", + f"copy {source_file} error: {e}!", + AscendCLogLevel.LOG_ERROR, + LogUtil.Option.NON_SOC, + ) + + def ascendc_copy_func(self: any): + os.chdir(self.in_path) + framework_catalog = os.listdir("framework") + for catalog_file in framework_catalog: + if ( + catalog_file == "tf_plugin" + or catalog_file == "caffe_plugin" + or catalog_file == "onnx_plugin" + ): + source_dir = "op_kernel/tbe/op_info_cfg/ai_core" + dst_dir = os.path.join(self.copy_path, "framework", self.framework_type) + self.ascendc_copy_file(source_dir, dst_dir) + source_dir = os.path.join("framework", catalog_file) + dst_dir = os.path.join(self.copy_path, "framework", self.framework_type) + self.ascendc_copy_file(source_dir, dst_dir) + source_dir = "op_kernel/tbe/op_info_cfg/ai_core" + dst_dir = os.path.join(self.copy_path, "op_impl/ai_core/tbe/config") + self.ascendc_copy_dir(source_dir, dst_dir) + source_dir = "op_kernel/binary/dynamic" + dst_dir = os.path.join( + self.copy_path, "op_impl/ai_core/tbe", self.vendor_name + "_impl", "dynamic" + ) + self.ascendc_copy_file(source_dir, dst_dir) + for compute_unit in self.op_soc_ver: + source_dir = os.path.join("op_kernel/binary", compute_unit) + dst_dir = os.path.join( + self.copy_path, "op_impl/ai_core/tbe/kernel", compute_unit + ) + self.ascendc_copy_dir(source_dir, dst_dir) + source_dir = "op_kernel/binary/config" + dst_dir = os.path.join(self.copy_path, 
"op_impl/ai_core/tbe/kernel/config") + self.ascendc_copy_dir(source_dir, dst_dir) + so_file = "op_impl/ai_core/tbe/op_master_device/lib/libcust_opmaster.so" + if os.path.exists(so_file): + dst_dir = os.path.join( + self.copy_path, "op_impl/ai_core/tbe/op_master_device/lib" + ) + os.makedirs(dst_dir, exist_ok=True) + shutil.copy(so_file, dst_dir) + + +def args_parse(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", "--input-path", nargs="?", help="Input path of compile result." + ) + parser.add_argument( + "-c", "--copy-path", nargs="?", help="Copy path of compile result." + ) + parser.add_argument( + "-o", "--output-path", nargs="?", help="Output path of compile result." + ) + parser.add_argument("-n", "--vendor-name", nargs="?", help="Vendor name.") + parser.add_argument("-u", "--compute-unit", nargs="?", help="Compute unit.") + parser.add_argument( + "-t", "--framework-type", nargs="?", help="Framework type, eg:tensorflow." + ) + parser.add_argument( + "-p", + "--platform", + nargs="?", + default=None, + help="target platform is x86_64 or aarch64.", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = args_parse() + kernel_packer = PackKernel(args) + kernel_packer.ascendc_copy_func() + kernel_packer.ascendc_get_config() + kernel_packer.ascendc_pack_kernel() + kernel_packer.ascendc_gen_header() + kernel_packer.ascendc_gen_lib() diff --git a/csrc/deepep/ops2/cmake/util/ascendc_replay_build.py b/csrc/deepep/ops2/cmake/util/ascendc_replay_build.py new file mode 100755 index 00000000..e07545f5 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/ascendc_replay_build.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import os +import sys + +import const_var +import opdesc_parser +import replay_codegen +from replay_codegen import ReplayCodeGenParams + +PYF_PATH = os.path.dirname(os.path.realpath(__file__)) + + +class ReplayBuilder(opdesc_parser.OpDesc): + def __init__(self: any, op_type: str): + super().__init__(op_type) + + def gen_replay_source(self: any, impl_path: str, out_path: str, ops_product: str): + if not self.op_replay_flag: + print("{} replay not enabled".format(self.op_type)) + return + argn = len(self.input_name) + len(self.output_name) + 1 + if self.op_replay_batch: + print("{} replay in batch mode".format(self.op_type)) + else: + print("{} replay in normal mode".format(self.op_type)) + if impl_path.endswith("op_kernel"): + implf = os.path.join(impl_path, self.op_file + ".cpp") + tiling_file = os.path.join( + impl_path, "../op_host", self.op_file + "_tiling.h" + ) + else: + if self.dynamic_shape: + dyn_path = "dynamic" + else: + dyn_path = "" + implf = os.path.join(impl_path, dyn_path, self.op_file + ".cpp") + tiling_file = os.path.join( + impl_path, "../../op_tiling", self.op_file + "_tiling.h" + ) + rep_conf = replay_codegen.ReplayCodeGen( + ReplayCodeGenParams( + self.op_type, + implf, + tiling_file, + self.op_file, + self.op_intf, + argn, + self.op_replay_batch, + self.max_block_dim, + self.max_shape_size, + ) + ) + rep_conf.set_batch(self.op_replay_batch) + rep_conf.set_outdir(out_path) + rep_conf.gen_replay(ops_product) + + +def gen_replay( + cfgfile: str, cfgs: dict, dirs: dict, ops_product: str, ops: list = None +): + batch_lists = cfgs.get(const_var.REPLAY_BATCH).split(";") + iterator_lists = cfgs.get(const_var.REPLAY_ITERATE).split(";") + op_descs = opdesc_parser.get_op_desc( + cfgfile, batch_lists, iterator_lists, ReplayBuilder, ops + ) + for op_desc in op_descs: + op_desc.gen_replay_source( + 
dirs.get(const_var.CFG_IMPL_DIR), + dirs.get(const_var.CFG_OUT_DIR), + ops_product, + ) + + +if __name__ == "__main__": + if len(sys.argv) <= 6: + raise RuntimeError("arguments must greater than 6") + rep_cfg = {} + rep_cfg[const_var.REPLAY_BATCH] = sys.argv[2] + rep_cfg[const_var.REPLAY_ITERATE] = sys.argv[3] + rep_dir = {} + rep_dir[const_var.CFG_IMPL_DIR] = sys.argv[4] + rep_dir[const_var.CFG_OUT_DIR] = sys.argv[5] + gen_replay(sys.argv[1], rep_cfg, rep_dir, sys.argv[6]) diff --git a/csrc/deepep/ops2/cmake/util/batch_replay_impl.temp b/csrc/deepep/ops2/cmake/util/batch_replay_impl.temp new file mode 100755 index 00000000..0e883466 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/batch_replay_impl.temp @@ -0,0 +1,117 @@ +#include +#include +#include +#include +#include +#include +#include +#include "replay_def.h" +#include "code_gen.h" +#include "replay_fun.h" +#include "register/op_check.h" +#define __ASCENDC_REPLAY_CODE__ +#include + +using namespace std; +using namespace optiling; +using namespace AscendCReplay; + +extern "C" void __KERNEL_FUN__ (__ARGS_DEF__, const char *); +extern "C" int elf_batch_append(char *elf, uint32_t elfSize, char *jit, int kernum, char *atext[], int alen[], + int atlen, const char* kernelname[]); + +#define KERNEL_N 1 +#define ARG_N (__ARG_NUM__) +#define MAX_L (1024 * 1024 * 100) +#define MAX_E (1024 * 1024) + +int __KERNEL_FUN___replay___OPS_PRODUCT__(ReplayFuncParam& param, const int core_type) +{ + // gen type 1 : direct call codes 0: load .o file + if (param.gentype < 0 || param.gentype > 1) { + printf("Error: call replay gen type is %d, should only be 1 or 0\n", param.gentype); + return 0; + } else if (param.gentype == 1 && param.objptr == nullptr) { + printf("Error: call replay with direct call mode, but code obj addr is null\n"); + return 0; + } else if (param.gentype == 0 && param.output_kernel_file == nullptr) { + printf("Error: call replay with object file mode, but object file path is null\n"); + return 0; + } + // core_type 0:MIX 1:CUBE 2:VEC + if (core_type < 0 || core_type > 2) { + printf("Error: call replay core type is %d !\n", core_type); + return 0; + } + g_coreType = __CORE_TYPE__; + g_taskRation = param.task_ration; + g_tilingKey = param.tiling_key; + + unsigned char *buf, *jit; + char *kernel[KERNEL_N]; + int len[KERNEL_N]; + block_idx = 0; + block_num = param.block_dim; + g_ubBase = block_num; + uint8_t *code = (uint8_t *)malloc(MAX_L); + uint8_t *pos = code; + struct timespec tp1, tp2; + + clock_gettime(CLOCK_MONOTONIC, &tp1); + if (block_num > 32) { + printf("Error: block_num > 32\n"); + return 0; + } + //__OP_FOPEN__ + for (int i = 0; i < KERNEL_N; i++) { + //__OP_SET_KERNEL__ + for (int j = 0; j < ARG_N; j++) + AddArg(j, ARG_STEP * (j + 1)); +#ifdef FP_CEILING + SetCtrlFloatEnable(); +#else + SetCtrlFloatDisable(); +#endif + CodeInit(pos, true); + __KERNEL_FUN__(__KERNEL_ARGS__, param.tiling_data); + CodeEnd(); + kernel[i] = (char *)pos; + len[i] = CodeLen(); + pos += len[i]; + } + //__OP_FCLOSE__ + clock_gettime(CLOCK_MONOTONIC, &tp2); + buf = (unsigned char *)malloc(MAX_E); + int fd = open(param.entry_file, O_RDONLY); + if (fd < 0) { + printf("[error]: cannot find entry.o : %s\n", param.entry_file); + return 0; + } + uint32_t bufSize = read(fd, buf, MAX_E); + if (bufSize <= 0) { + printf("[error]: entry.o : %s is too small ! 
\n", param.entry_file); + } + close(fd); + jit = (unsigned char *)malloc(MAX_L); + printf("total code generated %ld\n", pos - code); + int sz = elf_batch_append((char *)buf, bufSize, (char *)jit, KERNEL_N, kernel, len, pos - code, ¶m.kernel_name); + if (tp1.tv_sec != tp2.tv_sec) { + printf("%ld NS\n", tp2.tv_nsec + 1000000000 - tp1.tv_nsec); + } else { + printf("%ld NS\n", tp2.tv_nsec - tp1.tv_nsec); + } + printf("new elf size %d\n", sz); + if (param.gentype == 0) { + fd = open(param.output_kernel_file, O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR); + (void)write(fd, jit, sz); + close(fd); + free(jit); + } else if (param.gentype == 1) { + *param.objptr = (char*)jit; + } + free(buf); + free(code); + return sz; +} + +REG_REPLAY_FUNC(__OPTYPE__, __OPS_PRODUCT__, __KERNEL_FUN___replay___OPS_PRODUCT__); diff --git a/csrc/deepep/ops2/cmake/util/code_channel_infer.py b/csrc/deepep/ops2/cmake/util/code_channel_infer.py new file mode 100755 index 00000000..c9042d4d --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/code_channel_infer.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import collections +import copy +import ctypes +import os +import shutil +import stat +import subprocess + +"""CODE_* is used to cube/vector api is called in operator code +CODE_MIX means both cube and vector api is called +CODE_CUBE means only cube api is called +CODE_VEC means only vector api is called +""" +CODE_MIX = 0 +CODE_CUBE = 1 +CODE_VEC = 2 + + +def _is_v220(op_product: str): + """return if current soc version is V220 + + Returns: + res: True means V220 + """ + if op_product == "ascend910_93" or op_product == "ascend910b": + return True + return False + + +InfoCodeChanelParams = collections.namedtuple( + "InfoCodeChanelParams", + [ + "src_file", + "tiling_header", + "kernel_name", + "outdir", + "op_product", + "compile_options", + ], +) + + +def infer_code_channel(params: InfoCodeChanelParams): + """get code channel for v220, return CODE_MIX if soc version is not V220 + + Args: + src_file (str): AscendC operator code file + src_file (str): AscendC operator tiling header file + kernel_name (str): kernel function name + optype (str): operator type + compile_options (list): compile options for bisheng cmd + + Raises: + Exception: if not exist L1/L0/UB if code, it's not a aicore code + + Returns: + res (int): CODE_MIX/CODE_CUBE/CODE_VEC + """ + if not _is_v220(params.op_product): + return CODE_MIX + return CODE_VEC diff --git a/csrc/deepep/ops2/cmake/util/const_var.py b/csrc/deepep/ops2/cmake/util/const_var.py new file mode 100755 index 00000000..bc8f33bf --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/const_var.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import os +import stat + +REPLAY_BATCH = "batch" +REPLAY_ITERATE = "iterate" +CFG_IMPL_DIR = "impl_dir" +CFG_OUT_DIR = "out_dir" +AUTO_GEN_DIR = "auto_gen_dir" +WFLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC +WMODES = stat.S_IWUSR | stat.S_IRUSR +SOC_MAP_EXT = { + "ascend310p": "Ascend310P3", + "ascend310b": "Ascend310B1", + "ascend910": "Ascend910A", + "ascend910b": "Ascend910B1", + "ascend910_93": "Ascend910_9391", + "ascend610lite": "Ascend610Lite", +} +BIN_CMD = "opc $1 --main_func={fun} --input_param={param} --soc_version={soc} \ +--output=$2 --impl_mode={impl} --simplified_key_mode=0 --op_mode=dynamic\n" +SET_PLOG_LEVEL_ERROR = "export ASCEND_GLOBAL_LOG_LEVEL=3\n" +SET_PLOG_STDOUT = "export ASCEND_SLOG_PRINT_TO_STDOUT=1\n" +SRC_ENV = """ +while true; do + case "$1" in + --kernel-src=*) + export 
BUILD_KERNEL_SRC=$(echo "$1" | cut -d"=" -f2-) + shift + ;; + -*) + shift + ;; + *) + break + ;; + esac +done +""" +CHK_CMD = """ +if ! test -f $2/{res_file} ; then + echo "$2/{res_file} not generated!" + exit 1 +fi +""" +ATTR_DEF_VAL = { + "str": "", + "int": 0, + "float": 0.0, + "bool": False, + "list_bool": [], + "list_int": [], + "list_float": [], + "list_list_int": [[]], +} + + +def conv_soc_ver(ver: str): + return SOC_MAP_EXT.get(ver) diff --git a/csrc/deepep/ops2/cmake/util/gen_impl_and_mrege_json.sh b/csrc/deepep/ops2/cmake/util/gen_impl_and_mrege_json.sh new file mode 100755 index 00000000..93d7ec84 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/gen_impl_and_mrege_json.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +project_path=$1 +build_path=$2 +vendor_name=customize +if [[ ! -d "$project_path" ]]; then + echo "[ERROR] No project path is provided" + exit 1 +fi + +if [[ ! -d "$build_path" ]]; then + echo "[ERROR] No build path is provided" + exit 1 +fi + +# copy aicpu kernel so operators +if [[ -d "${project_path}/cpukernel/aicpu_kernel_lib" ]]; then + cp -f ${project_path}/cpukernel/aicpu_kernel_lib/* ${build_path}/makepkg/packages/vendors/$vendor_name/op_impl/cpu/aicpu_kernel/impl + rm -rf ${project_path}/cpukernel/aicpu_kernel_lib +fi diff --git a/csrc/deepep/ops2/cmake/util/gen_ops_filter.sh b/csrc/deepep/ops2/cmake/util/gen_ops_filter.sh new file mode 100755 index 00000000..b06a4e9f --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/gen_ops_filter.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +if [[ -z "$1" ]]; then + echo -e "[ERROR] No source dir provided" + exit 1 +fi + +if [[ -z "$2" ]]; then + echo -e "[ERROR] No destination dir provided" + exit 1 +fi + +src=$1 +dest_file=$2/npu_supported_ops.json + +if [ -f "$dest_file" ];then + chmod u+w $dest_file +fi + +echo $* + +add_ops() { + name=$1 + isHeavy=$2 + file=$3 + grep -w "\"$name\"" ${file} >/dev/null + if [ $? == 0 ];then + return + fi + echo " \"${name}\": {" >> ${file} + echo " \"isGray\": false," >> ${file} + echo " \"isHeavy\": ${isHeavy}" >> ${file} + echo " }," >> ${file} +} + +echo "{" > ${dest_file} +ini_files=$(find ${src} -name "*.ini") +for file in ${ini_files} ; do + name=$(grep '^\[' ${file} | sed 's/\[//g' | sed 's/]//g' | sed 's/\r//g') + grep 'heavyOp.flag' ${file} >/dev/null + if [ $? 
== 0 ];then + isHeavy=$(grep 'heavyOp.flag' ${file} | awk -F= '{print $2}') + else + isHeavy="false" + fi + for op in ${name} ; do + add_ops ${op} "false" ${dest_file} + done +done +echo "}" >> ${dest_file} +file_count=$(cat ${dest_file} | wc -l) +line=$(($file_count-1)) +sed -i "${line}{s/,//g}" ${dest_file} + +chmod 640 "${dest_file}" +echo -e "[INFO] Succeed generated ${dest_file}" + +exit 0 diff --git a/csrc/deepep/ops2/cmake/util/gen_version_info.sh b/csrc/deepep/ops2/cmake/util/gen_version_info.sh new file mode 100755 index 00000000..8468f949 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/gen_version_info.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +ascend_install_dir=$1 +gen_file_dir=$2 + +# create version.info +compiler_version=$(grep "Version" -w ${ascend_install_dir}/compiler/version.info | awk -F = '{print $2}') +echo "custom_opp_compiler_version=${compiler_version}" > ${gen_file_dir}/version.info diff --git a/csrc/deepep/ops2/cmake/util/insert_op_info.py b/csrc/deepep/ops2/cmake/util/insert_op_info.py new file mode 100755 index 00000000..ca7562e7 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/insert_op_info.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +import json +import os +import stat +import sys + +import const_var + +if __name__ == "__main__": + if len(sys.argv) != 3: + print(sys.argv) + print("argv error, inert_op_info.py your_op_file lib_op_file") + sys.exit(2) + + with open(sys.argv[1], "r") as load_f: + insert_operator = json.load(load_f) + + all_operators = {} + if os.path.exists(sys.argv[2]): + if os.path.getsize(sys.argv[2]) != 0: + with open(sys.argv[2], "r") as load_f: + all_operators = json.load(load_f) + + for k in insert_operator.keys(): + if k in all_operators.keys(): + print("replace op:[", k, "] success") + else: + print("insert op:[", k, "] success") + all_operators[k] = insert_operator[k] + + with os.fdopen( + os.open(sys.argv[2], const_var.WFLAGS, const_var.WMODES), "w" + ) as json_file: + json_file.write(json.dumps(all_operators, indent=4)) diff --git a/csrc/deepep/ops2/cmake/util/insert_simplified_keys.py b/csrc/deepep/ops2/cmake/util/insert_simplified_keys.py new file mode 100755 index 00000000..599ebe97 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/insert_simplified_keys.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import argparse +import glob +import json +import os +import re +import sys + +import const_var + +DATA_TPYE_DICT = { + "float32": 0, + "float16": 1, + "int8": 2, + "int16": 6, + "uint16": 7, + "uint8": 4, + "int32": 3, + "int64": 9, + "uint32": 8, + "uint64": 10, + "bool": 12, + "double": 11, + "string": 13, + "dual": 14, + "dual": 15, + "complex64": 16, + "complex128": 17, + "qint8": 18, + "qint16": 19, + "qint32": 20, + "quint8": 21, + "quint16": 22, + "resource": 23, + "string": 24, + "dual": 25, + "variant": 26, + "bf16": 27, + "bfloat16": 27, + "undefined": 28, + "int4": 29, + "uint1": 30, + "int2": 31, +} + +FORMAT_DICT = { + "NCHW": 0, + "NHWC": 1, + "ND": 2, + "NC1HWC0": 3, + "FRACTAL_Z": 4, + "NC1C0HWPAD": 5, + "NHWC1C0": 6, + "FSR_NCHW": 7, + "FRACTAL_DECONV": 8, + "C1HWNC0": 9, + "FRACTAL_DECONV_TRANSPOSE": 10, + "FRACTAL_DECONV_SP_STRIDE_TRANS": 11, + "NC1HWC0_C04": 12, + "FRACTAL_Z_C04": 13, + "CHWN": 14, + "FRACTAL_DECONV_SP_STRIDE8_TRANS": 15, + "HWCN": 16, + "NC1KHKWHWC0": 17, + "BN_WEIGHT": 18, + "FILTER_HWCK": 19, + "HASHTABLE_LOOKUP_LOOKUPS": 20, + "HASHTABLE_LOOKUP_KEYS": 21, + "HASHTABLE_LOOKUP_VALUE": 22, + "HASHTABLE_LOOKUP_OUTPUT": 23, + "HASHTABLE_LOOKUP_HITS": 24, + "C1HWNCoC0": 25, + "MD": 
26, + "NDHWC": 27, + "FRACTAL_ZZ": 28, + "FRACTAL_NZ": 29, + "NCDHW": 30, + "DHWCN": 31, + "NDC1HWC0": 32, + "FRACTAL_Z_3D": 33, + "CN": 34, + "NC": 35, + "DHWNC": 36, + "FRACTAL_Z_3D_TRANSPOSE": 37, + "FRACTAL_ZN_LSTM": 38, + "FRACTAL_Z_G": 39, + "RESERVED": 40, + "ALL": 41, + "NULL": 42, + "ND_RNN_BIAS": 43, + "FRACTAL_ZN_RNN": 44, + "NYUV": 45, + "NYUV_A": 46, +} + + +def load_json(json_file: str): + with open(json_file, encoding="utf-8") as file: + json_content = json.load(file) + return json_content + + +def get_specified_suffix_file(root_dir, suffix): + specified_suffix = os.path.join(root_dir, "**/*.{}".format(suffix)) + all_suffix_files = glob.glob(specified_suffix, recursive=True) + return all_suffix_files + + +def get_deterministic_value(support_info): + deterministic_key = "deterministic" + if deterministic_key not in support_info: + return 0 + deterministic_value = support_info.get(deterministic_key) + if deterministic_value == "true": + return 1 + else: + return 0 + + +def get_precision_value(support_info): + precision_key = "implMode" + precision_value = support_info.get(precision_key) + if precision_value == "high_performance": + _value = 1 + elif precision_value == "high_precision": + _value = 2 + else: + _value = 0 + return _value + + +def get_overflow_value(support_info): + return 0 + + +def get_parameters(info): + if info: + if "dtype" in info: + data_type = info["dtype"] + data_type_value = DATA_TPYE_DICT.get(data_type) + else: + data_type_value = 0 + if "format" in info: + _format = info["format"] + _format_value = FORMAT_DICT.get(_format) + else: + _format_value = 0 + else: + data_type_value = 0 + _format_value = 0 + return str(data_type_value), str(_format_value) + + +def get_dynamic_parameters(info): + # 动态输入时只需获取第一个参数 + return get_parameters(info[0]) + + +def get_all_parameters(support_info, _type): + result_list = list() + info_lists = support_info.get(_type) + if info_lists: + for _info in info_lists: + # 输入为列表时是动态输入 + if isinstance(_info, (list, tuple)): + data_type_value, _format_value = get_dynamic_parameters(_info) + else: + data_type_value, _format_value = get_parameters(_info) + result_list.append("{},{}".format(data_type_value, _format_value)) + return result_list + + +def get_all_input_parameters(support_info): + result = get_all_parameters(support_info, "inputs") + return "/".join(result) + + +def insert_content_into_file(input_file, content): + with open(input_file, "r+") as file: + lines = file.readlines() + for index, line in enumerate(lines): + match_result = re.search(r'"staticKey":', line) + if match_result: + count = len(line) - len(line.lstrip()) + new_content = "{}{}".format(" " * count, content) + # 插入到前一行,防止插入最后时还需要考虑是否添加逗号 + lines.insert(index, new_content) + break + file.seek(0) + file.write("".join(lines)) + + +def insert_simplified_keys(json_file): + contents = load_json(json_file) + # 不存在'binFileName'或者'supportInfo'字段时,非需要替换的解析json文件 + if ("binFileName" not in contents) or ("supportInfo" not in contents): + return + support_info = contents.get("supportInfo") + bin_file_name = contents.get("binFileName") + # 'simplifiedKey'字段已经存在时,直接返回,不重复生成 + if "simplifiedKey" in support_info: + return + op_type = bin_file_name.split("_")[0] + deterministic = str(get_deterministic_value(support_info)) + precision = str(get_precision_value(support_info)) + overflow = str(get_overflow_value(support_info)) + input_parameters = get_all_input_parameters(support_info) + key = "{}/d={},p={},o={}/{}/".format( + op_type, deterministic, precision, overflow, 
input_parameters + ) + result = '"simplifiedKey": "' + key + '",\n' + insert_content_into_file(json_file, result) + + +def insert_all_simplified_keys(root_dir): + suffix = "json" + all_json_files = get_specified_suffix_file(root_dir, suffix) + for _json in all_json_files: + insert_simplified_keys(_json) + + +def args_prase(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-p", + "--path", + nargs="?", + required=True, + help="Parse the path of the json file.", + ) + return parser.parse_args() + + +def main(): + args = args_prase() + insert_all_simplified_keys(args.path) + + +if __name__ == "__main__": + main() diff --git a/csrc/deepep/ops2/cmake/util/kernel_entry.py b/csrc/deepep/ops2/cmake/util/kernel_entry.py new file mode 100755 index 00000000..255266fd --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/kernel_entry.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + + +def gen_fun_def(title, kernel, argn, arg_type, arg_name): + entry = [] + entry.append(title) + entry.append(kernel) + entry.append("(") + args = [] + for i in range(0, argn): + args.append(arg_type + " " + arg_name + str(i)) + entry.append(", ".join(args)) + entry.append(")") + return " ".join(entry) + + +def gen_batch_kernel_body(fname, argn, arg_name): + body = [] + body.append("{") + fun = [] + fun.append(fname) + fun.append("(") + args = [] + for i in range(0, argn): + args.append(arg_name + str(i)) + fun.append(", ".join(args)) + fun.append(");") + body.append(" ".join(fun)) + body.append("}") + return "\n".join(body) + + +def gen_mc_kernel_body(kn, argn, arg_name, blknum): + body = [] + body.append("{") + body.append(" switch(block_idx) {") + for blk in range(0, blknum): + fun = [] + fun.append("{}_blk{:02d}".format(kn, blk)) + fun.append("(") + args = [] + for i in range(0, argn): + args.append(arg_name + str(i)) + fun.append(", ".join(args)) + fun.append(")") + body.append(" case {}: {}; break;".format(blk, " ".join(fun))) + body.append(" default: break;") + body.append(" }") + body.append("}") + return "\n".join(body) + + +def gen_proc_body(argn, arg_name): + body = [] + body.append("{") + args = [] + for i in range(0, argn): + args.append(arg_name + str(i)) + body.append("uint64_t __x = (uint64_t)" + " + (uint64_t)".join(args) + ";") + body.append('__asm__ ("NOP");') + body.append('__asm__ ("NOP");') + body.append('__asm__ ("NOP");') + body.append("}") + return "\n".join(body) + + +def batch_code_gen(kn, argn, argt): + codes = [] + kernel_name = kn + proc_name = kernel_name + "_percore" + arg_num = int(argn) + data_type = argt + arg_type = "__gm__ " + data_type + "* __restrict__" + arg_name = "arg" + kernel_title = 'extern "C" __global__ __aicore__ void' + proc_title = 'extern "C" __attribute__((noinline)) __aicore__ void' + codes.append("#ifndef __aicore__") + codes.append("#define __aicore__ [aicore]") + codes.append("#endif") + codes.append(gen_fun_def(proc_title, proc_name, arg_num, arg_type, arg_name) + ";") + codes.append(gen_fun_def(kernel_title, kernel_name, arg_num, arg_type, arg_name)) + codes.append(gen_batch_kernel_body(proc_name, arg_num, arg_name)) + codes.append(gen_fun_def(proc_title, proc_name, arg_num, arg_type, arg_name)) + codes.append(gen_proc_body(arg_num, arg_name)) + return "\n".join(codes) + "\n" + + +def mc_code_gen(kn, argn, argt, blknum): + codes = [] + kernel_name = kn + core_num = int(blknum) + arg_num = int(argn) + data_type = argt + arg_type = "__gm__ " + data_type + "* __restrict__" + arg_name = "arg" + kernel_title = 'extern "C" __global__ 
__aicore__ void' + proc_title = 'extern "C" __attribute__((noinline)) __aicore__ void' + codes.append("#ifndef __aicore__") + codes.append("#define __aicore__ [aicore]") + codes.append("#endif") + for i in range(0, core_num): + proc_name = "{}_blk{:02d}".format(kernel_name, i) + codes.append( + gen_fun_def(proc_title, proc_name, arg_num, arg_type, arg_name) + ";" + ) + codes.append(gen_fun_def(kernel_title, kernel_name, arg_num, arg_type, arg_name)) + codes.append(gen_mc_kernel_body(kernel_name, arg_num, arg_name, core_num)) + for i in range(0, core_num): + proc_name = "{}_blk{:02d}".format(kernel_name, i) + codes.append(gen_fun_def(proc_title, proc_name, arg_num, arg_type, arg_name)) + codes.append(gen_proc_body(arg_num, arg_name)) + return "\n".join(codes) + "\n" diff --git a/csrc/deepep/ops2/cmake/util/kernel_impl.temp b/csrc/deepep/ops2/cmake/util/kernel_impl.temp new file mode 100755 index 00000000..5079a104 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/kernel_impl.temp @@ -0,0 +1,10 @@ +#include +#include +#include +#include +#include +#include "replay_def.h" +#include "code_gen.h" +#include "replay_fun.h" +#define __ASCENDC_REPLAY_CODE__ +#include "__CCE_FILE__" diff --git a/csrc/deepep/ops2/cmake/util/makeself/COPYING b/csrc/deepep/ops2/cmake/util/makeself/COPYING new file mode 100755 index 00000000..d159169d --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/makeself/COPYING @@ -0,0 +1,339 @@ + GNU GENERAL PUBLIC LICENSE + Version 2, June 1991 + + Copyright (C) 1989, 1991 Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The licenses for most software are designed to take away your +freedom to share and change it. By contrast, the GNU General Public +License is intended to guarantee your freedom to share and change free +software--to make sure the software is free for all its users. This +General Public License applies to most of the Free Software +Foundation's software and to any other program whose authors commit to +using it. (Some other Free Software Foundation software is covered by +the GNU Lesser General Public License instead.) You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +this service if you wish), that you receive source code or can get it +if you want it, that you can change the software or use pieces of it +in new free programs; and that you know you can do these things. + + To protect your rights, we need to make restrictions that forbid +anyone to deny you these rights or to ask you to surrender the rights. +These restrictions translate to certain responsibilities for you if you +distribute copies of the software, or if you modify it. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must give the recipients all the rights that +you have. You must make sure that they, too, receive or can get the +source code. And you must show them these terms so they know their +rights. + + We protect your rights with two steps: (1) copyright the software, and +(2) offer you this license which gives you legal permission to copy, +distribute and/or modify the software. 
+ + Also, for each author's protection and ours, we want to make certain +that everyone understands that there is no warranty for this free +software. If the software is modified by someone else and passed on, we +want its recipients to know that what they have is not the original, so +that any problems introduced by others will not reflect on the original +authors' reputations. + + Finally, any free program is threatened constantly by software +patents. We wish to avoid the danger that redistributors of a free +program will individually obtain patent licenses, in effect making the +program proprietary. To prevent this, we have made it clear that any +patent must be licensed for everyone's free use or not licensed at all. + + The precise terms and conditions for copying, distribution and +modification follow. + + GNU GENERAL PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. This License applies to any program or other work which contains +a notice placed by the copyright holder saying it may be distributed +under the terms of this General Public License. The "Program", below, +refers to any such program or work, and a "work based on the Program" +means either the Program or any derivative work under copyright law: +that is to say, a work containing the Program or a portion of it, +either verbatim or with modifications and/or translated into another +language. (Hereinafter, translation is included without limitation in +the term "modification".) Each licensee is addressed as "you". + +Activities other than copying, distribution and modification are not +covered by this License; they are outside its scope. The act of +running the Program is not restricted, and the output from the Program +is covered only if its contents constitute a work based on the +Program (independent of having been made by running the Program). +Whether that is true depends on what the Program does. + + 1. You may copy and distribute verbatim copies of the Program's +source code as you receive it, in any medium, provided that you +conspicuously and appropriately publish on each copy an appropriate +copyright notice and disclaimer of warranty; keep intact all the +notices that refer to this License and to the absence of any warranty; +and give any other recipients of the Program a copy of this License +along with the Program. + +You may charge a fee for the physical act of transferring a copy, and +you may at your option offer warranty protection in exchange for a fee. + + 2. You may modify your copy or copies of the Program or any portion +of it, thus forming a work based on the Program, and copy and +distribute such modifications or work under the terms of Section 1 +above, provided that you also meet all of these conditions: + + a) You must cause the modified files to carry prominent notices + stating that you changed the files and the date of any change. + + b) You must cause any work that you distribute or publish, that in + whole or in part contains or is derived from the Program or any + part thereof, to be licensed as a whole at no charge to all third + parties under the terms of this License. 
+ + c) If the modified program normally reads commands interactively + when run, you must cause it, when started running for such + interactive use in the most ordinary way, to print or display an + announcement including an appropriate copyright notice and a + notice that there is no warranty (or else, saying that you provide + a warranty) and that users may redistribute the program under + these conditions, and telling the user how to view a copy of this + License. (Exception: if the Program itself is interactive but + does not normally print such an announcement, your work based on + the Program is not required to print an announcement.) + +These requirements apply to the modified work as a whole. If +identifiable sections of that work are not derived from the Program, +and can be reasonably considered independent and separate works in +themselves, then this License, and its terms, do not apply to those +sections when you distribute them as separate works. But when you +distribute the same sections as part of a whole which is a work based +on the Program, the distribution of the whole must be on the terms of +this License, whose permissions for other licensees extend to the +entire whole, and thus to each and every part regardless of who wrote it. + +Thus, it is not the intent of this section to claim rights or contest +your rights to work written entirely by you; rather, the intent is to +exercise the right to control the distribution of derivative or +collective works based on the Program. + +In addition, mere aggregation of another work not based on the Program +with the Program (or with a work based on the Program) on a volume of +a storage or distribution medium does not bring the other work under +the scope of this License. + + 3. You may copy and distribute the Program (or a work based on it, +under Section 2) in object code or executable form under the terms of +Sections 1 and 2 above provided that you also do one of the following: + + a) Accompany it with the complete corresponding machine-readable + source code, which must be distributed under the terms of Sections + 1 and 2 above on a medium customarily used for software interchange; or, + + b) Accompany it with a written offer, valid for at least three + years, to give any third party, for a charge no more than your + cost of physically performing source distribution, a complete + machine-readable copy of the corresponding source code, to be + distributed under the terms of Sections 1 and 2 above on a medium + customarily used for software interchange; or, + + c) Accompany it with the information you received as to the offer + to distribute corresponding source code. (This alternative is + allowed only for noncommercial distribution and only if you + received the program in object code or executable form with such + an offer, in accord with Subsection b above.) + +The source code for a work means the preferred form of the work for +making modifications to it. For an executable work, complete source +code means all the source code for all modules it contains, plus any +associated interface definition files, plus the scripts used to +control compilation and installation of the executable. However, as a +special exception, the source code distributed need not include +anything that is normally distributed (in either source or binary +form) with the major components (compiler, kernel, and so on) of the +operating system on which the executable runs, unless that component +itself accompanies the executable. 
+ +If distribution of executable or object code is made by offering +access to copy from a designated place, then offering equivalent +access to copy the source code from the same place counts as +distribution of the source code, even though third parties are not +compelled to copy the source along with the object code. + + 4. You may not copy, modify, sublicense, or distribute the Program +except as expressly provided under this License. Any attempt +otherwise to copy, modify, sublicense or distribute the Program is +void, and will automatically terminate your rights under this License. +However, parties who have received copies, or rights, from you under +this License will not have their licenses terminated so long as such +parties remain in full compliance. + + 5. You are not required to accept this License, since you have not +signed it. However, nothing else grants you permission to modify or +distribute the Program or its derivative works. These actions are +prohibited by law if you do not accept this License. Therefore, by +modifying or distributing the Program (or any work based on the +Program), you indicate your acceptance of this License to do so, and +all its terms and conditions for copying, distributing or modifying +the Program or works based on it. + + 6. Each time you redistribute the Program (or any work based on the +Program), the recipient automatically receives a license from the +original licensor to copy, distribute or modify the Program subject to +these terms and conditions. You may not impose any further +restrictions on the recipients' exercise of the rights granted herein. +You are not responsible for enforcing compliance by third parties to +this License. + + 7. If, as a consequence of a court judgment or allegation of patent +infringement or for any other reason (not limited to patent issues), +conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot +distribute so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you +may not distribute the Program at all. For example, if a patent +license would not permit royalty-free redistribution of the Program by +all those who receive copies directly or indirectly through you, then +the only way you could satisfy both it and this License would be to +refrain entirely from distribution of the Program. + +If any portion of this section is held invalid or unenforceable under +any particular circumstance, the balance of the section is intended to +apply and the section as a whole is intended to apply in other +circumstances. + +It is not the purpose of this section to induce you to infringe any +patents or other property right claims or to contest validity of any +such claims; this section has the sole purpose of protecting the +integrity of the free software distribution system, which is +implemented by public license practices. Many people have made +generous contributions to the wide range of software distributed +through that system in reliance on consistent application of that +system; it is up to the author/donor to decide if he or she is willing +to distribute software through any other system and a licensee cannot +impose that choice. + +This section is intended to make thoroughly clear what is believed to +be a consequence of the rest of this License. + + 8. 
If the distribution and/or use of the Program is restricted in +certain countries either by patents or by copyrighted interfaces, the +original copyright holder who places the Program under this License +may add an explicit geographical distribution limitation excluding +those countries, so that distribution is permitted only in or among +countries not thus excluded. In such case, this License incorporates +the limitation as if written in the body of this License. + + 9. The Free Software Foundation may publish revised and/or new versions +of the General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + +Each version is given a distinguishing version number. If the Program +specifies a version number of this License which applies to it and "any +later version", you have the option of following the terms and conditions +either of that version or of any later version published by the Free +Software Foundation. If the Program does not specify a version number of +this License, you may choose any version ever published by the Free Software +Foundation. + + 10. If you wish to incorporate parts of the Program into other free +programs whose distribution conditions are different, write to the author +to ask for permission. For software which is copyrighted by the Free +Software Foundation, write to the Free Software Foundation; we sometimes +make exceptions for this. Our decision will be guided by the two goals +of preserving the free status of all derivatives of our free software and +of promoting the sharing and reuse of software generally. + + NO WARRANTY + + 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY +FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN +OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES +PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED +OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS +TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE +PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, +REPAIR OR CORRECTION. + + 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR +REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, +INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING +OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED +TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY +YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER +PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE +POSSIBILITY OF SUCH DAMAGES. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +convey the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. 
+ + + Copyright (C) + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +Also add information on how to contact you by electronic and paper mail. + +If the program is interactive, make it output a short notice like this +when it starts in an interactive mode: + + Gnomovision version 69, Copyright (C) year name of author + Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, the commands you use may +be called something other than `show w' and `show c'; they could even be +mouse-clicks or menu items--whatever suits your program. + +You should also get your employer (if you work as a programmer) or your +school, if any, to sign a "copyright disclaimer" for the program, if +necessary. Here is a sample; alter the names: + + Yoyodyne, Inc., hereby disclaims all copyright interest in the program + `Gnomovision' (which makes passes at compilers) written by James Hacker. + + , 1 April 1989 + Ty Coon, President of Vice + +This General Public License does not permit incorporating your program into +proprietary programs. If your program is a subroutine library, you may +consider it more useful to permit linking proprietary applications with the +library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. diff --git a/csrc/deepep/ops2/cmake/util/makeself/README.md b/csrc/deepep/ops2/cmake/util/makeself/README.md new file mode 100755 index 00000000..9d3d4b86 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/makeself/README.md @@ -0,0 +1,246 @@ +[![License: GPL v2](https://img.shields.io/badge/License-GPL%20v2-blue.svg)](https://www.gnu.org/licenses/old-licenses/gpl-2.0.en.html) +![Build Status](https://github.com/megastep/makeself/workflows/CI/badge.svg) + +# makeself - Make self-extractable archives on Unix + +[makeself.sh][1] is a small shell script that generates a self-extractable +compressed tar archive from a directory. The resulting file appears as a shell script +(many of those have a **.run** suffix), and can be launched as is. The archive +will then uncompress itself to a temporary directory and an optional arbitrary +command will be executed (for example an installation script). This is pretty +similar to archives generated with WinZip Self-Extractor in the Windows world. +Makeself archives also include checksums for integrity self-validation (CRC +and/or MD5/SHA256 checksums). + +The makeself.sh script itself is used only to create the archives from a +directory of files. The resultant archive is actually a compressed (using +gzip, bzip2, or compress) TAR archive, with a small shell script stub at the +beginning. 
This small stub performs all the steps of extracting the files, +running the embedded command, and removing the temporary files when done. +All the user has to do to install the software contained in such an +archive is to "run" the archive, i.e **sh nice-software.run**. I recommend +using the ".run" (which was introduced by some Makeself archives released by +Loki Software) or ".sh" suffix for such archives not to confuse the users, +so that they will know they are actually shell scripts (with quite a lot of binary data +attached to them though!). + +I am trying to keep the code of this script as portable as possible, i.e it is +not relying on any bash-specific features and only calls commands that are +installed on any functioning UNIX-compatible system. This script as well as +the archives it generates should run on any Unix flavor, with any compatible +Bourne shell, provided of course that the compression programs are available. + +As of version 2.1, Makeself has been rewritten and tested on the following +platforms : + + * Linux (all distributions) + * Sun Solaris (8 and above) + * HP-UX (tested on 11.0 and 11i on HPPA RISC) + * SCO OpenUnix and OpenServer + * IBM AIX 5.1L + * macOS (Darwin) + * SGI IRIX 6.5 + * FreeBSD + * UnicOS / Cray + * Cygwin (Windows) + +If you successfully run Makeself and/or archives created with it on another +system, then please [let me know][2]! + +Examples of publicly available archives made using makeself are : + + * Game patches and installers for [Id Software][3] games like Quake 3 for Linux or Return To Castle Wolfenstein ; + * All game patches released by [Loki Software][4] for the Linux version of popular games ; + * The [nVidia drivers][5] for Linux + * The installer for the Linux version of [Google Earth][6] + * The [VirtualBox][7] installers for Linux + * The [Makeself][1] distribution itself ;-) + * and countless others... + +**Important note for Apache users:** By default, most Web servers will think that Makeself archives are regular text files and thus they may show up as text in a Web browser. The correct way to prevent this is to add a MIME type for this file format, like so (in httpd.conf) : + +`AddType application/x-makeself .run` + +**Important note for certain GNU/Linux distributions:** Archives created with Makeself prior to v2.1.2 were using an old syntax for the _head_ and _tail_ Unix commands that is being progressively obsoleted in their GNU forms. Therefore you may have problems uncompressing some of these archives. A workaround for this is to set the environment variable $_POSIX2_VERSION to enable the old syntax, i.e. : + +`export _POSIX2_VERSION=199209` + +## Usage + +The syntax of makeself is the following: + +``` +makeself.sh [args] archive_dir file_name label startup_script [script_args] +``` + + * _args_ are optional options for Makeself. The available ones are : + + * **`--version`** : Prints the version number on stdout, then exits immediately + * **`--gzip`** : Use gzip for compression (the default on platforms on which gzip is commonly available, like Linux) + * **`--bzip2`** : Use bzip2 instead of gzip for better compression. The bzip2 command must be available in the command path. It is recommended that the archive prefix be set to something like '.bz2.run', so that potential users know that they'll need bzip2 to extract it. + * **`--pbzip2`** : Use pbzip2 instead of gzip for better and faster compression on machines having multiple CPUs. The pbzip2 command must be available in the command path. 
It is recommended that the archive prefix be set to something like '.bz2.run', so that potential users know that they'll need bzip2 to extract it. + * **`--xz`** : Use xz instead of gzip for better compression. The xz command must be available in the command path. It is recommended that the archive prefix be set to something like '.xz.run' for the archive, so that potential users know that they'll need xz to extract it. + * **`--lzo`** : Use lzop instead of gzip for better compression. The lzop command must be available in the command path. It is recommended that the archive prefix be set to something like `.lzo.run` for the archive, so that potential users know that they'll need lzop to extract it. + * **`--lz4`** : Use lz4 instead of gzip for better compression. The lz4 command must be available in the command path. It is recommended that the archive prefix be set to something like '.lz4.run' for the archive, so that potential users know that they'll need lz4 to extract it. + * **`--zstd`** : Use zstd instead of gzip for better compression. The zstd command must be available in the command path. It is recommended that the archive prefix be set to something like '.zstd.run' for the archive, so that potential users know that they'll need zstd to extract it. + * **`--pigz`** : Use pigz for compression. + * **`--base64`** : Encode the archive to ASCII in Base64 format instead of compressing (base64 command required). + * **`--gpg-encrypt`** : Encrypt the archive using `gpg -ac -z $COMPRESS_LEVEL`. This will prompt for a password to encrypt with. Assumes that potential users have `gpg` installed. + * **`--ssl-encrypt`** : Encrypt the archive using `openssl aes-256-cbc -a -salt`. This will prompt for a password to encrypt with. Assumes that the potential users have the OpenSSL tools installed. + * **`--compress`** : Use the UNIX `compress` command to compress the data. This should be the default on all platforms that don't have gzip available. + * **`--nocomp`** : Do not use any compression for the archive, which will then be an uncompressed TAR. + * **`--complevel`** : Specify the compression level for gzip, bzip2, pbzip2, zstd, xz, lzo or lz4. (defaults to 9) + * **`--threads`** : Specify the number of threads to be used by compressors that support parallelization. Omit to use compressor's default. Most useful (and required) for opting into xz's threading, usually with `--threads=0` for all available cores. pbzip2 and pigz are parallel by default, and setting this value allows limiting the number of threads they use. + * **`--notemp`** : The generated archive will not extract the files to a temporary directory, but in a new directory created in the current directory. This is better to distribute software packages that may extract and compile by themselves (i.e. launch the compilation through the embedded script). + * **`--current`** : Files will be extracted to the current directory, instead of in a subdirectory. This option implies `--notemp` above. + * **`--follow`** : Follow the symbolic links inside of the archive directory, i.e. store the files that are being pointed to instead of the links themselves. + * **`--append`** _(new in 2.1.x)_: Append data to an existing archive, instead of creating a new one. In this mode, the settings from the original archive are reused (compression type, label, embedded script), and thus don't need to be specified again on the command line. + * **`--header`** : Makeself uses a separate file to store the header stub, called `makeself-header.sh`. 
By default, it is assumed that it is stored in the same location as makeself.sh. This option can be used to specify its actual location if it is stored someplace else. + * **`--cleanup`** : Specify a script that is run when execution is interrupted or finishes successfully. The script is executed with the same environment and initial `script_args` as `startup_script`. + * **`--copy`** : Upon extraction, the archive will first extract itself to a temporary directory. The main application of this is to allow self-contained installers stored in a Makeself archive on a CD, when the installer program will later need to unmount the CD and allow a new one to be inserted. This prevents "Filesystem busy" errors for installers that span multiple CDs. + * **`--nox11`** : Disable the automatic spawning of a new terminal in X11. + * **`--nowait`** : When executed from a new X11 terminal, disable the user prompt at the end of the script execution. + * **`--nomd5`** and **`--nocrc`** : Disable the creation of a MD5 / CRC checksum for the archive. This speeds up the extraction process if integrity checking is not necessary. + * **`--sha256`** : Adds a SHA256 checksum for the archive. This is in addition to the MD5 / CRC checksums unless `--nomd5` is also used. + * **`--lsm` _file_** : Provide an LSM file to makeself, that will be embedded in the generated archive. LSM files describe a software package in a way that is easily parseable. The LSM entry can then be later retrieved using the `--lsm` argument to the archive. An example of an LSM file is provided with Makeself. + * **`--tar-format opt`** : Specify the tar archive format (default is ustar); you may use any value accepted by your tar command (such as posix, v7, etc). + * **`--tar-extra opt`** : Append more options to the tar command line. + + For instance, in order to exclude the `.git` directory from the packaged archive directory using the GNU `tar`, one can use `makeself.sh --tar-extra "--exclude=.git" ...` + + * **`--keep-umask`** : Keep the umask set to shell default, rather than overriding when executing self-extracting archive. + * **`--packaging-date date`** : Use provided string as the packaging date instead of the current date. + * **`--license`** : Append a license file. + * **`--nooverwrite`** : Do not extract the archive if the specified target directory already exists. + * **`--help-header file`** : Add a header to the archive's `--help` output. + * `archive_dir` is the name of the directory that contains the files to be archived + * `file_name` is the name of the archive to be created + * `label` is an arbitrary text string describing the package. It will be displayed while extracting the files. + * `startup_script` is the command to be executed _from within_ the directory of extracted files. Thus, if you wish to execute a program contained in this directory, you must prefix your command with `./`. For example, `./program` will be fine. The `script_args` are additional arguments for this command.
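As a hedged illustration of how several of the options above combine, the sketch below packages a hypothetical directory into a self-extracting installer; the directory, archive name, scripts, and prefix are purely illustrative and not part of the Makeself distribution, while every flag used is one documented in the list above.

```
# Hypothetical example: package ./myapp-1.0 into myapp-1.0.run,
# excluding .git via GNU tar, adding a SHA256 checksum on top of the
# default MD5/CRC ones, and registering a cleanup script that runs on
# interrupt or after the embedded installer finishes.
makeself.sh \
    --tar-extra "--exclude=.git" \
    --sha256 \
    --complevel 6 \
    --cleanup ./cleanup.sh \
    myapp-1.0 myapp-1.0.run "MyApp 1.0 installer" ./install.sh --prefix=/opt/myapp
```

The resulting archive could then be inspected with `--info`, `--list`, or `--check` before execution, as described in the runtime-argument list further below.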
+ +Here is an example, assuming the user has a package image stored in a **/home/joe/mysoft**, and he wants to generate a self-extracting package named +**mysoft.sh**, which will launch the "setup" script initially stored in /home/joe/mysoft : + +`makeself.sh /home/joe/mysoft mysoft.sh "Joe's Nice Software Package" ./setup +` + +Here is also how I created the [makeself.run][9] archive which contains the Makeself distribution : + +`makeself.sh --notemp makeself makeself.run "Makeself by Stephane Peter" echo "Makeself has extracted itself" ` + +Archives generated with Makeself can be passed the following arguments: + + * **`--keep`** : Prevent the files to be extracted in a temporary directory that will be removed after the embedded script's execution. The files will then be extracted in the current working directory and will stay here until you remove them. + * **`--verbose`** : Will prompt the user before executing the embedded command + * **`--target dir`** : Allows to extract the archive in an arbitrary place. + * **`--nox11`** : Do not spawn a X11 terminal. + * **`--confirm`** : Prompt the user for confirmation before running the embedded command. + * **`--info`** : Print out general information about the archive (does not extract). + * **`--lsm`** : Print out the LSM entry, if it is present. + * **`--list`** : List the files in the archive. + * **`--check`** : Check the archive for integrity using the embedded checksums. Does not extract the archive. + * **`--nochown`** : By default, a `chown -R` command is run on the target directory after extraction, so that all files belong to the current user. This is mostly needed if you are running as root, as tar will then try to recreate the initial user ownerships. You may disable this behavior with this flag. + * **`--tar`** : Run the tar command on the contents of the archive, using the following arguments as parameter for the command. + * **`--noexec`** : Do not run the embedded script after extraction. + * **`--noexec-cleanup`** : Do not run the embedded cleanup script. + * **`--nodiskspace`** : Do not check for available disk space before attempting to extract. + * **`--cleanup-args`** : Specify arguments to be passed to the cleanup script. Wrap value in quotes to specify multiple arguments. + +Any subsequent arguments to the archive will be passed as additional arguments to the embedded command. You must explicitly use the `--` special command-line construct before any such options to make sure that Makeself will not try to interpret them. + +## Startup Script + +The startup script must be a regular Shell script. + +Within the startup script, you can use the `$USER_PWD` variable to get the path of the folder from which the self-extracting script is executed. This is especially useful to access files that are located in the same folder as the script, as shown in the example below. + +`my-self-extracting-script.sh --fooBarFileParameter foo.bar` + +## Building and Testing + +Clone the git repo and execute `git submodule update --init --recursive` to obtain all submodules. + +* To make a release: `make` +* To run all tests: `make test` + +## Maven Usage + +Makeself is now supported by the following maven plugin [makeself-maven-plugin](https://github.com/hazendaz/makeself-maven-plugin). Please refer to project for usage and report any bugs in regards to maven plugin on that project. + +## License + +Makeself itself is covered by the [GNU General Public License][8] (GPL) version 2 and above. 
Archives generated by Makeself don't have to be placed under this license (although I encourage it ;-)), since the archive itself is merely data for Makeself. + +## Contributing + +I will gladly consider merging your pull requests on the [GitHub][10] repository. However, please keep the following in mind: + + * One of the main purposes of Makeself is portability. Do not submit patches that will break supported platforms. The more platform-agnostic, the better. + * Please explain clearly what the purpose of the patch is, and how you achieved it. + +## Download + +Get the latest official distribution [here][9] (version 2.4.2). + +The latest development version can be grabbed from [GitHub][10]. Feel free to submit any patches there through the fork and pull request process. + +## Version history + + * **v1.0:** Initial public release + * **v1.1:** The archive can be passed parameters that will be passed on to the embedded script, thanks to John C. Quillan + * **v1.2:** Cosmetic updates, support for bzip2 compression and non-temporary archives. Many ideas thanks to Francois Petitjean. + * **v1.3:** More patches from Bjarni R. Einarsson and Francois Petitjean: Support for no compression (`--nocomp`), script is no longer mandatory, automatic launch in an xterm, optional verbose output, and -target archive option to indicate where to extract the files. + * **v1.4:** Many patches from Francois Petitjean: improved UNIX compatibility, automatic integrity checking, support of LSM files to get info on the package at run time.. + * **v1.5.x:** A lot of bugfixes, and many other patches, including automatic verification through the usage of checksums. Version 1.5.5 was the stable release for a long time, even though the Web page didn't get updated ;-). Makeself was also officially made a part of the [Loki Setup installer][11], and its source is being maintained as part of this package. + * **v2.0:** Complete internal rewrite of Makeself. The command-line parsing was vastly improved, the overall maintenance of the package was greatly improved by separating the stub from makeself.sh. Also Makeself was ported and tested to a variety of Unix platforms. + * **v2.0.1:** First public release of the new 2.0 branch. Prior versions are officially obsoleted. This release introduced the `--copy` argument that was introduced in response to a need for the [UT2K3][12] Linux installer. + * **v2.1.0:** Big change : Makeself can now support multiple embedded tarballs, each stored separately with their own checksums. An existing archive can be updated with the `--append` flag. Checksums are also better managed, and the `--nochown` option for archives appeared. + * **v2.1.1:** Fixes related to the Unix compression (compress command). Some Linux distributions made the insane choice to make it unavailable, even though gzip is capable of uncompressing these files, plus some more bugfixes in the extraction and checksum code. + * **v2.1.2:** Some bug fixes. Use head -n to avoid problems with POSIX conformance. + * **v2.1.3:** Bug fixes with the command line when spawning terminals. Added `--tar`, `--noexec` for archives. Added `--nomd5` and `--nocrc` to avoid creating checksums in archives. The embedded script is now run through "eval". The `--info` output now includes the command used to create the archive. A man page was contributed by Bartosz Fenski. + * **v2.1.4:** Fixed `--info` output. Generate random directory name when extracting files to . to avoid problems. 
Better handling of errors with wrong permissions for the directory containing the files. Avoid some race conditions, Unset the $CDPATH variable to avoid problems if it is set. Better handling of dot files in the archive directory. + * **v2.1.5:** Made the md5sum detection consistent with the header code. Check for the presence of the archive directory. Added `--encrypt` for symmetric encryption through gpg (Eric Windisch). Added support for the digest command on Solaris 10 for MD5 checksums. Check for available disk space before extracting to the target directory (Andreas Schweitzer). Allow extraction to run asynchronously (patch by Peter Hatch). Use file descriptors internally to avoid error messages (patch by Kay Tiong Khoo). + * **v2.1.6:** Replaced one dot per file progress with a realtime progress percentage and a spinning cursor. Added `--noprogress` to prevent showing the progress during the decompression. Added `--target` dir to allow extracting directly to a target directory. (Guy Baconniere) + * **v2.2.0:** First major new release in years! Includes many bugfixes and user contributions. Please look at the [project page on Github][10] for all the details. + * **v2.3.0:** Support for archive encryption via GPG or OpenSSL. Added LZO and LZ4 compression support. Options to set the packaging date and stop the umask from being overridden. Optionally ignore check for available disk space when extracting. New option to check for root permissions before extracting. + * **v2.3.1:** Various compatibility updates. Added unit tests for Travis CI in the GitHub repo. New `--tar-extra`, `--untar-extra`, `--gpg-extra`, `--gpg-asymmetric-encrypt-sign` options. + * **v2.4.0:** Added optional support for SHA256 archive integrity checksums. + * **v2.4.2:** New --cleanup and --cleanup-args arguments for cleanup scripts. Added threading support for supported compressors. Now supports zstd compression. + * **v2.4.3:** Make explicit POSIX tar archives for increased compatibility. + * **v2.4.4:** Fixed various compatibility issues (no longer use POSIX tar archives), Github Actions to check on Solaris and FreeBSD. + * **v2.4.5:** Added `--tar-format` option to set the tar archive format (default is ustar) + +## Links + + * Check out the ["Loki Setup"][11] installer, used to install many Linux games and other applications, and of which I am the co-author. Since the demise of Loki, I am now the official maintainer of the project, and it is now being hosted here on GitHub. + * Bjarni R. Einarsson also wrote the **setup.sh** installer script, inspired by Makeself. [Check it out !][14] + +## Contact + +This script was written by [Stéphane Peter][15] (megastep at megastep.org). Any enhancements and suggestions are welcome. + +Contributions were included from John C. Quillan, Bjarni R. Einarsson, +Francois Petitjean, Ryan C. Gordon, and many contributors on GitHub. If you think I forgot +your name, don't hesitate to contact me. + +This project is now hosted on GitHub. Feel free to submit patches and bug reports on the [project page][10]. 
+ +* * * + +[Stephane Peter][2] + + [1]: http://makeself.io/ + [2]: mailto:megastep@megastep.org + [3]: http://www.idsoftware.com/ + [4]: http://www.lokigames.com/products/myth2/updates.php3 + [5]: http://www.nvidia.com/ + [6]: http://earth.google.com/ + [7]: http://www.virtualbox.org/ + [8]: http://www.gnu.org/copyleft/gpl.html + [9]: https://github.com/megastep/makeself/releases/download/release-2.4.5/makeself-2.4.5.run + [10]: https://github.com/megastep/makeself + [11]: https://github.com/megastep/loki_setup/ + [12]: http://www.unrealtournament2003.com/ + [13]: http://www.icculus.org/ + [14]: http://bre.klaki.net/programs/setup.sh/ + [15]: https://stephanepeter.com/ diff --git a/csrc/deepep/ops2/cmake/util/makeself/VERSION b/csrc/deepep/ops2/cmake/util/makeself/VERSION new file mode 100755 index 00000000..59aa62c1 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/makeself/VERSION @@ -0,0 +1 @@ +2.4.5 diff --git a/csrc/deepep/ops2/cmake/util/makeself/make-release.sh b/csrc/deepep/ops2/cmake/util/makeself/make-release.sh new file mode 100755 index 00000000..65d698f2 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/makeself/make-release.sh @@ -0,0 +1,8 @@ +#!/bin/sh +# +# Create a distributable archive of the current version of Makeself + +VER=`cat VERSION` +mkdir -p /tmp/makeself-$VER release +cp -pPR makeself* test README.md COPYING VERSION .gitmodules /tmp/makeself-$VER/ +./makeself.sh --notemp /tmp/makeself-$VER release/makeself-$VER.run "Makeself v$VER" echo "Makeself has extracted itself" diff --git a/csrc/deepep/ops2/cmake/util/makeself/makeself-header.sh b/csrc/deepep/ops2/cmake/util/makeself/makeself-header.sh new file mode 100755 index 00000000..23ffc483 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/makeself/makeself-header.sh @@ -0,0 +1,660 @@ +cat << EOF > "$archname" +#!/bin/bash +# This script was generated using Makeself $MS_VERSION +# The license covering this archive and its contents, if any, is wholly independent of the Makeself license (GPL) +# 2022.3.19-Modified the MS_Help function and some options +# Huawei Technologies Co., Ltd. + +ORIG_UMASK=\`umask\` + +CRCsum="$CRCsum" +MD5="$MD5sum" +SHA="$SHAsum" +SIGNATURE="$Signature" +TMPROOT=\${TMPDIR:="\$HOME"} +if ! test -d "\$TMPROOT"; then + TMPROOT="\$PWD" +fi +export TMPDIR="\$TMPROOT" +USER_PWD="\$PWD" +if ! 
test -d "\$USER_PWD"; then + exit 1 +fi +export USER_PWD +ARCHIVE_DIR=\`dirname "\$0"\` +export ARCHIVE_DIR + +name_of_file="\$0 " +pwd_of_file="\$PWD" +label="$LABEL" +script="$SCRIPT" +scriptargs="$SCRIPTARGS" +cleanup_script="${CLEANUP_SCRIPT}" +licensetxt="$LICENSE" +helpheader='$HELPHEADER' +targetdir="$archdirname" +filesizes="$filesizes" +totalsize="$totalsize" +keep="$KEEP" +nooverwrite="$NOOVERWRITE" +quiet="n" +accept="n" +nodiskspace="n" +export_conf="$EXPORT_CONF" +decrypt_cmd="$DECRYPT_CMD" +skip="$SKIP" + +print_cmd_arg="" +if type printf > /dev/null; then + print_cmd="printf" +elif test -x /usr/ucb/echo; then + print_cmd="/usr/ucb/echo" +else + print_cmd="echo" +fi + +if test -d /usr/xpg4/bin; then + PATH=/usr/xpg4/bin:\$PATH + export PATH +fi + +if test -d /usr/sfw/bin; then + PATH=\$PATH:/usr/sfw/bin + export PATH +fi + +unset CDPATH + +MS_Printf() +{ + \$print_cmd \$print_cmd_arg "\$1" +} + +MS_PrintLicense() +{ + PAGER=\${PAGER:=more} + if test x"\$licensetxt" != x; then + PAGER_PATH=\`exec <&- 2>&-; which \$PAGER || command -v \$PAGER || type \$PAGER\` + if test -x "\$PAGER_PATH"; then + echo "\$licensetxt" | \$PAGER + else + echo "\$licensetxt" + fi + if test x"\$accept" != xy; then + while true + do + MS_Printf "Please type y to accept, n otherwise: " + read yn + if test x"\$yn" = xn; then + keep=n + eval \$finish; exit 1 + break; + elif test x"\$yn" = xy; then + break; + fi + done + fi + fi +} + +MS_diskspace() +{ + ( + df -kP "\$1" | tail -1 | awk '{ if (\$4 ~ /%/) {print \$3} else {print \$4} }' + ) +} + +MS_dd() +{ + blocks=\`expr \$3 / 1024\` + bytes=\`expr \$3 % 1024\` + # Test for ibs, obs and conv feature + if dd if=/dev/zero of=/dev/null count=1 ibs=512 obs=512 conv=sync 2> /dev/null; then + dd if="\$1" ibs=\$2 skip=1 obs=1024 conv=sync 2> /dev/null | \\ + { test \$blocks -gt 0 && dd ibs=1024 obs=1024 count=\$blocks ; \\ + test \$bytes -gt 0 && dd ibs=1 obs=1024 count=\$bytes ; } 2> /dev/null + else + dd if="\$1" bs=\$2 skip=1 2> /dev/null + fi +} + +MS_dd_Progress() +{ + if test x"\$noprogress" = xy; then + MS_dd "\$@" + return \$? + fi + file="\$1" + offset=\$2 + length=\$3 + pos=0 + bsize=4194304 + while test \$bsize -gt \$length; do + bsize=\`expr \$bsize / 4\` + done + blocks=\`expr \$length / \$bsize\` + bytes=\`expr \$length % \$bsize\` + ( + dd ibs=\$offset skip=1 2>/dev/null + pos=\`expr \$pos \+ \$bsize\` + MS_Printf " 0%% " 1>&2 + if test \$blocks -gt 0; then + while test \$pos -le \$length; do + dd bs=\$bsize count=1 2>/dev/null + pcent=\`expr \$length / 100\` + pcent=\`expr \$pos / \$pcent\` + if test \$pcent -lt 100; then + MS_Printf "\b\b\b\b\b\b\b" 1>&2 + if test \$pcent -lt 10; then + MS_Printf " \$pcent%% " 1>&2 + else + MS_Printf " \$pcent%% " 1>&2 + fi + fi + pos=\`expr \$pos \+ \$bsize\` + done + fi + if test \$bytes -gt 0; then + dd bs=\$bytes count=1 2>/dev/null + fi + MS_Printf "\b\b\b\b\b\b\b" 1>&2 + MS_Printf " 100%% " 1>&2 + ) < "\$file" +} + +MS_Help() +{ + cat << EOH >&2 +Usage: \$0 [options] +Options: + --help | -h Print this message + --info Print embedded info : title, default target directory, embedded script ... + --list Print the list of files in the archive + --check Checks integrity and version dependency of the archive + --quiet Quiet install mode, skip human-computer interactions + --nox11 Do not spawn an xterm + --noexec Do not run embedded script + --extract= Extract directly to a target directory (absolute or relative) + Usually used with --noexec to just extract files without running + --tar arg1 [arg2 ...] 
Access the contents of the archive through the tar command +\${helpheader} +EOH +} + +MS_Verify_Sig() +{ + GPG_PATH=\`exec <&- 2>&-; which gpg || command -v gpg || type gpg\` + MKTEMP_PATH=\`exec <&- 2>&-; which mktemp || command -v mktemp || type mktemp\` + test -x "\$GPG_PATH" || GPG_PATH=\`exec <&- 2>&-; which gpg || command -v gpg || type gpg\` + test -x "\$MKTEMP_PATH" || MKTEMP_PATH=\`exec <&- 2>&-; which mktemp || command -v mktemp || type mktemp\` + offset=\`head -n "\$skip" "\$1" | wc -c | tr -d " "\` + temp_sig=\`mktemp -t XXXXX\` + echo \$SIGNATURE | base64 --decode > "\$temp_sig" + gpg_output=\`MS_dd "\$1" \$offset \$totalsize | LC_ALL=C "\$GPG_PATH" --verify "\$temp_sig" - 2>&1\` + gpg_res=\$? + rm -f "\$temp_sig" + if test \$gpg_res -eq 0 && test \`echo \$gpg_output | grep -c Good\` -eq 1; then + if test \`echo \$gpg_output | grep -c \$sig_key\` -eq 1; then + test x"\$quiet" = xn && echo "GPG signature is good" >&2 + else + echo "GPG Signature key does not match" >&2 + exit 2 + fi + else + test x"\$quiet" = xn && echo "GPG signature failed to verify" >&2 + exit 2 + fi +} + +MS_Check() +{ + OLD_PATH="\$PATH" + PATH=\${GUESS_MD5_PATH:-"\$OLD_PATH:/bin:/usr/bin:/sbin:/usr/local/ssl/bin:/usr/local/bin:/opt/openssl/bin"} + MD5_ARG="" + MD5_PATH=\`exec <&- 2>&-; which md5sum || command -v md5sum || type md5sum\` + test -x "\$MD5_PATH" || MD5_PATH=\`exec <&- 2>&-; which md5 || command -v md5 || type md5\` + test -x "\$MD5_PATH" || MD5_PATH=\`exec <&- 2>&-; which digest || command -v digest || type digest\` + PATH="\$OLD_PATH" + + SHA_PATH=\`exec <&- 2>&-; which shasum || command -v shasum || type shasum\` + test -x "\$SHA_PATH" || SHA_PATH=\`exec <&- 2>&-; which sha256sum || command -v sha256sum || type sha256sum\` + + if test x"\$quiet" = xn; then + MS_Printf "Verifying archive integrity..." + fi + offset=\`head -n "\$skip" "\$1" | wc -c | tr -d " "\` + fsize=\`cat "\$1" | wc -c | tr -d " "\` + if test \$totalsize -ne \`expr \$fsize - \$offset\`; then + echo " Unexpected archive size." >&2 + exit 2 + fi + verb=\$2 + i=1 + for s in \$filesizes + do + crc=\`echo \$CRCsum | cut -d" " -f\$i\` + if test -x "\$SHA_PATH"; then + if test x"\`basename \$SHA_PATH\`" = xshasum; then + SHA_ARG="-a 256" + fi + sha=\`echo \$SHA | cut -d" " -f\$i\` + if test x"\$sha" = x0000000000000000000000000000000000000000000000000000000000000000; then + test x"\$verb" = xy && echo " \$1 does not contain an embedded SHA256 checksum." >&2 + else + shasum=\`MS_dd_Progress "\$1" \$offset \$s | eval "\$SHA_PATH \$SHA_ARG" | cut -b-64\`; + if test x"\$shasum" != x"\$sha"; then + echo "Error in SHA256 checksums: \$shasum is different from \$sha" >&2 + exit 2 + elif test x"\$quiet" = xn; then + MS_Printf " SHA256 checksums are OK." >&2 + fi + crc="0000000000"; + fi + fi + if test -x "\$MD5_PATH"; then + if test x"\`basename \$MD5_PATH\`" = xdigest; then + MD5_ARG="-a md5" + fi + md5=\`echo \$MD5 | cut -d" " -f\$i\` + if test x"\$md5" = x00000000000000000000000000000000; then + test x"\$verb" = xy && echo " \$1 does not contain an embedded MD5 checksum." >&2 + else + md5sum=\`MS_dd_Progress "\$1" \$offset \$s | eval "\$MD5_PATH \$MD5_ARG" | cut -b-32\`; + if test x"\$md5sum" != x"\$md5"; then + echo "Error in MD5 checksums: \$md5sum is different from \$md5" >&2 + exit 2 + elif test x"\$quiet" = xn; then + MS_Printf " MD5 checksums are OK." >&2 + fi + crc="0000000000"; verb=n + fi + fi + if test x"\$crc" = x0000000000; then + test x"\$verb" = xy && echo " \$1 does not contain a CRC checksum." 
>&2 + else + sum1=\`MS_dd_Progress "\$1" \$offset \$s | CMD_ENV=xpg4 cksum | awk '{print \$1}'\` + if test x"\$sum1" != x"\$crc"; then + echo "Error in checksums: \$sum1 is different from \$crc" >&2 + exit 2 + elif test x"\$quiet" = xn; then + MS_Printf " CRC checksums are OK." >&2 + fi + fi + i=\`expr \$i + 1\` + offset=\`expr \$offset + \$s\` + done + if test x"\$quiet" = xn; then + echo " All good." + fi +} + +MS_Decompress() +{ + if test x"\$decrypt_cmd" != x""; then + { eval "\$decrypt_cmd" || echo " ... Decryption failed." >&2; } | eval "$GUNZIP_CMD" + else + eval "$GUNZIP_CMD" + fi + + if test \$? -ne 0; then + echo " ... Decompression failed." >&2 + fi +} + +UnTAR() +{ + if test x"\$quiet" = xn; then + tar \$1vf - $UNTAR_EXTRA 2>&1 || { echo " ... Extraction failed." >&2; kill -15 \$$; } + else + tar \$1f - $UNTAR_EXTRA 2>&1 || { echo Extraction failed. >&2; kill -15 \$$; } + fi +} + +MS_exec_cleanup() { + if test x"\$cleanup" = xy && test x"\$cleanup_script" != x""; then + cleanup=n + cd "\$tmpdir" + eval "\"\$cleanup_script\" \$scriptargs \$cleanupargs" + fi +} + +MS_cleanup() +{ + echo 'Signal caught, cleaning up' >&2 + MS_exec_cleanup + cd "\$TMPROOT" + rm -rf "\$tmpdir" + eval \$finish; exit 15 +} + +Script_Args_Check() +{ + script_supported_args=\$(echo \${helpheader} | grep -o -E "\-\-[^ ]+" | awk -F"=" {'print \$1'}) + arg_to_test=\$(echo \$1|awk -F"=" {'print \$1'}) + + for arg in \${script_supported_args}; + do + if test x"\$arg_to_test" = x"\$arg" ;then + return + fi + done + + MS_Help + exit 1 +} + +finish=true +xterm_loop= +noprogress=$NOPROGRESS +nox11=$NOX11 +copy=$COPY +ownership=$OWNERSHIP +verbose=n +cleanup=y +cleanupargs= +sig_key= + +initargs="\$@" + +while [ -n "\$*" ] +do + case "\$1" in + -h | --help) + MS_Help + exit 0 + ;; + -q | --quiet) + quiet=y + noprogress=y + shift + ;; + --info) + echo Identification: "\$label" + echo Target directory: "\$targetdir" + echo Uncompressed size: $USIZE KB + echo Compression: $COMPRESS + if test x"$ENCRYPT" != x""; then + echo Encryption: $ENCRYPT + fi + echo Date of packaging: $DATE + echo Built with Makeself version $MS_VERSION + echo Build command was: "$MS_COMMAND" + if test x"\$script" != x; then + echo Script run after extraction: + echo " " \$script \$scriptargs + fi + if test x"$copy" = xcopy; then + echo "Archive will copy itself to a temporary location" + fi + if test x"$NEED_ROOT" = xy; then + echo "Root permissions required for extraction" + fi + if test x"$KEEP" = xy; then + echo "directory \$targetdir is permanent" + else + echo "\$targetdir will be removed after extraction" + fi + exit 0 + ;; + --list) + echo Target directory: \$targetdir + offset=\`head -n "\$skip" "\$0" | wc -c | tr -d " "\` + for s in \$filesizes + do + MS_dd "\$0" \$offset \$s | MS_Decompress | UnTAR t + offset=\`expr \$offset + \$s\` + done + exit 0 + ;; + --tar) + offset=\`head -n "\$skip" "\$0" | wc -c | tr -d " "\` + arg1="\$2" + shift 2 || { MS_Help; exit 1; } + for s in \$filesizes + do + MS_dd "\$0" \$offset \$s | MS_Decompress | tar "\$arg1" - "\$@" + offset=\`expr \$offset + \$s\` + done + exit 0 + ;; + --check) + MS_Check "\$0" y + scriptargs="\$scriptargs \$1" + shift + ;; + --noexec) + script="" + cleanup_script="" + shift + ;; + --extract=*) + keep=y + targetdir=\`echo \$1 | cut -d"=" -f2 \` + if ! 
shift; then MS_Help; exit 1; fi + ;; + --nox11) + nox11=y + shift + ;; + --xwin) + if test "$NOWAIT" = n; then + finish="echo Press Return to close this window...; read junk" + fi + xterm_loop=1 + shift + ;; + --phase2) + copy=phase2 + shift + ;; + --repack | --repack-path=*) + Script_Args_Check \$1 + scriptargs="\$scriptargs '\$1'" + shift + if [[ ! "\$1" =~ ^-.* ]]; then + scriptargs="\$scriptargs '\$1'" + shift + fi + ;; + *) + Script_Args_Check \$1 + scriptargs="\$scriptargs '\$1'" + shift + ;; + esac +done + +quiet_para="" +if test x"\$quiet" = xy; then + quiet_para="--quiet " +fi +scriptargs="--\$name_of_file""--\"\$pwd_of_file\""" \$quiet_para""\$scriptargs" + +if test x"\$quiet" = xy -a x"\$verbose" = xy; then + echo Cannot be verbose and quiet at the same time. >&2 + exit 1 +fi + +if test x"$NEED_ROOT" = xy -a \`id -u\` -ne 0; then + echo "Administrative privileges required for this archive (use su or sudo)" >&2 + exit 1 +fi + +if test x"\$copy" \!= xphase2; then + MS_PrintLicense +fi + +case "\$copy" in +copy) + tmpdir="\$TMPROOT"/makeself.\$RANDOM.\`date +"%y%m%d%H%M%S"\`.\$\$ + mkdir "\$tmpdir" || { + echo "Could not create temporary directory \$tmpdir" >&2 + exit 1 + } + SCRIPT_COPY="\$tmpdir/makeself" + echo "Copying to a temporary location..." >&2 + cp "\$0" "\$SCRIPT_COPY" + chmod +x "\$SCRIPT_COPY" + cd "\$TMPROOT" + exec "\$SCRIPT_COPY" --phase2 -- \$initargs + ;; +phase2) + finish="\$finish ; rm -rf \`dirname \$0\`" + ;; +esac + +if test x"\$nox11" = xn; then + if tty -s; then # Do we have a terminal? + : + else + if test x"\$DISPLAY" != x -a x"\$xterm_loop" = x; then # No, but do we have X? + if xset q > /dev/null 2>&1; then # Check for valid DISPLAY variable + GUESS_XTERMS="xterm gnome-terminal rxvt dtterm eterm Eterm xfce4-terminal lxterminal kvt konsole aterm terminology" + for a in \$GUESS_XTERMS; do + if type \$a >/dev/null 2>&1; then + XTERM=\$a + break + fi + done + chmod a+x \$0 || echo Please add execution rights on \$0 + if test \`echo "\$0" | cut -c1\` = "/"; then # Spawn a terminal! + exec \$XTERM -e "\$0 --xwin \$initargs" + else + exec \$XTERM -e "./\$0 --xwin \$initargs" + fi + fi + fi + fi +fi + +if test x"\$targetdir" = x.; then + tmpdir="." +else + if test x"\$keep" = xy; then + if test x"\$nooverwrite" = xy && test -d "\$targetdir"; then + echo "Target directory \$targetdir already exists, aborting." >&2 + exit 1 + fi + if test x"\$quiet" = xn; then + echo "Creating directory \$targetdir" >&2 + fi + tmpdir="\$targetdir" + dashp="-p" + else + tmpdir="\$TMPROOT/selfgz\$\$\$RANDOM" + dashp="" + fi + mkdir \$dashp "\$tmpdir" || { + echo 'Cannot create target directory' \$tmpdir >&2 + echo 'You should try option --extract=' >&2 + eval \$finish + exit 1 + } +fi + +location="\`pwd\`" +if test x"\$SETUP_NOCHECK" != x1; then + MS_Check "\$0" +fi +offset=\`head -n "\$skip" "\$0" | wc -c | tr -d " "\` + +if test x"\$verbose" = xy; then + MS_Printf "About to extract $USIZE KB in \$tmpdir ... Proceed ? [Y/n] " + read yn + if test x"\$yn" = xn; then + eval \$finish; exit 1 + fi +fi + +if test x"\$quiet" = xn; then + # Decrypting with openssl will ask for password, + # the prompt needs to start on new line + if test x"$ENCRYPT" = x"openssl"; then + echo "Decrypting and uncompressing \$label..." 
+ else + MS_Printf "Uncompressing \$label" + fi +fi +res=3 +if test x"\$keep" = xn; then + trap MS_cleanup 1 2 3 15 +fi + +if test x"\$nodiskspace" = xn; then + leftspace=\`MS_diskspace "\$tmpdir"\` + if test -n "\$leftspace"; then + if test "\$leftspace" -lt $USIZE; then + echo + echo "Not enough space left in "\`dirname \$tmpdir\`" (\$leftspace KB) to decompress \$0 ($USIZE KB)" >&2 + if test x"\$keep" = xn; then + echo "Consider setting TMPDIR to a directory with more free space." + fi + eval \$finish; exit 1 + fi + fi +fi + +for s in \$filesizes +do + if MS_dd_Progress "\$0" \$offset \$s | MS_Decompress | ( cd "\$tmpdir"; umask \$ORIG_UMASK ; UnTAR xp ) 1>/dev/null; then + if test x"\$ownership" = xy; then + (cd "\$tmpdir"; chown -R \`id -u\` .; chgrp -R \`id -g\` .) + fi + else + echo >&2 + echo "Unable to decompress \$0" >&2 + eval \$finish; exit 1 + fi + offset=\`expr \$offset + \$s\` +done +if test x"\$quiet" = xn; then + echo +fi + +cd "\$tmpdir" +res=0 +if test x"\$script" != x; then + if test x"\$export_conf" = x"y"; then + MS_BUNDLE="\$0" + MS_LABEL="\$label" + MS_SCRIPT="\$script" + MS_SCRIPTARGS="\$scriptargs" + MS_ARCHDIRNAME="\$archdirname" + MS_KEEP="\$KEEP" + MS_NOOVERWRITE="\$NOOVERWRITE" + MS_COMPRESS="\$COMPRESS" + MS_CLEANUP="\$cleanup" + export MS_BUNDLE MS_LABEL MS_SCRIPT MS_SCRIPTARGS + export MS_ARCHDIRNAME MS_KEEP MS_NOOVERWRITE MS_COMPRESS + fi + + if test x"\$verbose" = x"y"; then + yn="x" + while test x"\$yn" != x -a x"\$yn" != xy -a x"\$yn" != xY -a x"\$yn" != xn -a x"\$yn" != xN + do + MS_Printf "OK to execute: \$script \$scriptargs \$* ? [Y/n] " + read yn + if test x"\$yn" = x -o x"\$yn" = xy -o x"\$yn" = xY; then + eval "\"\$script\" \$scriptargs \"\\\$@\""; res=\$?; + elif test x"\$yn" = xn -o x"\$yn" = xN; then + echo "Unable to decompress \$script ,because of aborting! ";res=\$? + else + echo "Input value is unacceptable,please try again." + fi + done + else + eval "\"\$script\" \$scriptargs \"\\\$@\""; res=\$? + fi + if test "\$res" -ne 0; then + test x"\$verbose" = xy && echo "The program '\$script' returned an error code (\$res)" >&2 + fi +fi + +MS_exec_cleanup + +if test x"\$keep" = xn; then + cd "\$TMPROOT" + rm -rf "\$tmpdir" +fi +eval \$finish; exit \$res +EOF diff --git a/csrc/deepep/ops2/cmake/util/makeself/makeself.1 b/csrc/deepep/ops2/cmake/util/makeself/makeself.1 new file mode 100755 index 00000000..81bf6e4f --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/makeself/makeself.1 @@ -0,0 +1,110 @@ +.TH "MAKESELF" "1" "2.4.5" +.SH "NAME" +makeself \- An utility to generate self-extractable archives. +.SH "SYNTAX" +.B makeself [\fIoptions\fP] archive_dir file_name label +.B [\fIstartup_script\fP] [\fIargs\fP] +.SH "DESCRIPTION" +This program is a free (GPL) utility designed to create self-extractable +archives from a directory. +.SH "OPTIONS" +The following options are supported. +.TP 15 +.B -v, --version +Prints out the makeself version number and exits. +.TP +.B -h, --help +Print out help information. +.TP +.B --tar-quietly +Suppress verbose output from the tar command +.TP +.B --quiet +Do not print any messages other than errors +.TP +.B --gzip +Compress using gzip (default if detected). +.TP +.B --bzip2 +Compress using bzip2. +.TP +.B --pbzip2 +Compress using pbzip2. +.TP +.B --xz +Compress using xz. +.TP +.B --lzo +Compress using lzop. +.TP +.B --lz4 +Compress using lz4. +.TP +.B --compress +Compress using the UNIX 'compress' command. +.TP +.B --nocomp +Do not compress the data. 
+.TP +.B --complevel lvl +Specify the compression level for gzip,bzip2,pbzui2,xz,lzo or lz4 +.TP +.B --notemp +The archive will create archive_dir in the current directory and +uncompress in ./archive_dir. +.TP +.B --copy +Upon extraction, the archive will first copy itself to a temporary directory. +.TP +.B --append +Append more files to an existing makeself archive. The label and startup scripts will then be ignored. +.TP +.B --current +Files will be extracted to the current directory. Both --current and --target dir imply --notemp. +.TP +.B --target dir +Extract directly to a target directory. Directory path can be either absolute or relative. +.TP +.B --header file +Specify location of the header script. +.TP +.B --cleanup file +Specify a cleanup script that executes on interrupt and when finished successfully. +.TP +.B --follow +Follow the symlinks in the archive. +.TP +.B --noprogress +Do not show the progress during the decompression. +.TP +.B --nox11 +Disable automatic spawn of an xterm if running in X11. +.TP +.B --nowait +Do not wait for user input after executing embedded program from an xterm. +.TP +.B --nomd5 +Do not create a MD5 checksum for the archive. +.TP +.B --nocrc +Do not create a CRC32 checksum for the archive. +.TP +.B --lsm file +LSM file describing the package. +.B --packaging-date date +Use provided string as the packaging date instead of the current date. +.SH "EXAMPLES" +Here is an example, assuming the user has a package image stored in a /home/joe/mysoft, +and he wants to generate a self-extracting package named mysoft.sh, which will launch +the "setup" script initially stored in /home/joe/mysoft: +.TP +makeself.sh /home/joe/mysoft mysoft.sh "Joe's Nice Software Package" ./setup +.TP +Here is also how I created the makeself.run archive which contains the Makeself distribution: +.TP +makeself.sh --notemp makeself makeself.run "Makeself by Stephane Peter" echo "Makeself has extracted itself" +.SH "AUTHORS" +Makeself has been written by Stéphane Peter . +.BR +This man page was originally written by Bartosz Fenski for the +Debian GNU/Linux distribution (but it may be used by others). diff --git a/csrc/deepep/ops2/cmake/util/makeself/makeself.lsm b/csrc/deepep/ops2/cmake/util/makeself/makeself.lsm new file mode 100755 index 00000000..802cada3 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/makeself/makeself.lsm @@ -0,0 +1,16 @@ +Begin3 +Title: makeself.sh +Version: 2.4.5 +Description: makeself.sh is a shell script that generates a self-extractable + tar.gz archive from a directory. The resulting file appears as a shell + script, and can be launched as is. The archive will then uncompress + itself to a temporary directory and an arbitrary command will be + executed (for example an installation script). This is pretty similar + to archives generated with WinZip Self-Extractor in the Windows world. +Keywords: Installation archive tar winzip +Author: Stephane Peter (megastep@megastep.org) +Maintained-by: Stephane Peter (megastep@megastep.org) +Original-site: https://makeself.io/ +Platform: Unix +Copying-policy: GPL +End diff --git a/csrc/deepep/ops2/cmake/util/makeself/makeself.sh b/csrc/deepep/ops2/cmake/util/makeself/makeself.sh new file mode 100755 index 00000000..60ced4a8 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/makeself/makeself.sh @@ -0,0 +1,822 @@ +#!/bin/sh +# +# Makeself version 2.4.x +# by Stephane Peter +# +# Utility to create self-extracting tar.gz archives. 
+# The resulting archive is a file holding the tar.gz archive with +# a small Shell script stub that uncompresses the archive to a temporary +# directory and then executes a given script from within that directory. +# +# Makeself home page: https://makeself.io/ +# +# Version 2.0 is a rewrite of version 1.0 to make the code easier to read and maintain. +# +# Version history : +# - 1.0 : Initial public release +# - 1.1 : The archive can be passed parameters that will be passed on to +# the embedded script, thanks to John C. Quillan +# - 1.2 : Package distribution, bzip2 compression, more command line options, +# support for non-temporary archives. Ideas thanks to Francois Petitjean +# - 1.3 : More patches from Bjarni R. Einarsson and Francois Petitjean: +# Support for no compression (--nocomp), script is no longer mandatory, +# automatic launch in an xterm, optional verbose output, and -target +# archive option to indicate where to extract the files. +# - 1.4 : Improved UNIX compatibility (Francois Petitjean) +# Automatic integrity checking, support of LSM files (Francois Petitjean) +# - 1.5 : Many bugfixes. Optionally disable xterm spawning. +# - 1.5.1 : More bugfixes, added archive options -list and -check. +# - 1.5.2 : Cosmetic changes to inform the user of what's going on with big +# archives (Quake III demo) +# - 1.5.3 : Check for validity of the DISPLAY variable before launching an xterm. +# More verbosity in xterms and check for embedded command's return value. +# Bugfix for Debian 2.0 systems that have a different "print" command. +# - 1.5.4 : Many bugfixes. Print out a message if the extraction failed. +# - 1.5.5 : More bugfixes. Added support for SETUP_NOCHECK environment variable to +# bypass checksum verification of archives. +# - 1.6.0 : Compute MD5 checksums with the md5sum command (patch from Ryan Gordon) +# - 2.0 : Brand new rewrite, cleaner architecture, separated header and UNIX ports. +# - 2.0.1 : Added --copy +# - 2.1.0 : Allow multiple tarballs to be stored in one archive, and incremental updates. +# Added --nochown for archives +# Stopped doing redundant checksums when not necessary +# - 2.1.1 : Work around insane behavior from certain Linux distros with no 'uncompress' command +# Cleaned up the code to handle error codes from compress. Simplified the extraction code. +# - 2.1.2 : Some bug fixes. Use head -n to avoid problems. +# - 2.1.3 : Bug fixes with command line when spawning terminals. +# Added --tar for archives, allowing to give arbitrary arguments to tar on the contents of the archive. +# Added --noexec to prevent execution of embedded scripts. +# Added --nomd5 and --nocrc to avoid creating checksums in archives. +# Added command used to create the archive in --info output. +# Run the embedded script through eval. +# - 2.1.4 : Fixed --info output. +# Generate random directory name when extracting files to . to avoid problems. (Jason Trent) +# Better handling of errors with wrong permissions for the directory containing the files. (Jason Trent) +# Avoid some race conditions (Ludwig Nussel) +# Unset the $CDPATH variable to avoid problems if it is set. (Debian) +# Better handling of dot files in the archive directory. +# - 2.1.5 : Made the md5sum detection consistent with the header code. 
+# Check for the presence of the archive directory +# Added --encrypt for symmetric encryption through gpg (Eric Windisch) +# Added support for the digest command on Solaris 10 for MD5 checksums +# Check for available disk space before extracting to the target directory (Andreas Schweitzer) +# Allow extraction to run asynchronously (patch by Peter Hatch) +# Use file descriptors internally to avoid error messages (patch by Kay Tiong Khoo) +# - 2.1.6 : Replaced one dot per file progress with a realtime progress percentage and a spinning cursor (Guy Baconniere) +# Added --noprogress to prevent showing the progress during the decompression (Guy Baconniere) +# Added --target dir to allow extracting directly to a target directory (Guy Baconniere) +# - 2.2.0 : Many bugfixes, updates and contributions from users. Check out the project page on Github for the details. +# - 2.3.0 : Option to specify packaging date to enable byte-for-byte reproducibility. (Marc Pawlowsky) +# - 2.4.0 : Optional support for SHA256 checksums in archives. +# - 2.4.2 : Add support for threads for several compressors. (M. Limber) +# Added zstd support. +# - 2.4.3 : Make explicit POSIX tar archives for increased compatibility. +# - 2.4.5 : Added --tar-format to override ustar tar archive format +# +# (C) 1998-2021 by Stephane Peter +# +# This software is released under the terms of the GNU GPL version 2 and above +# Please read the license at http://www.gnu.org/copyleft/gpl.html +# Self-extracting archives created with this script are explicitly NOT released under the term of the GPL +# + +MS_VERSION=2.4.5 +MS_COMMAND="$0" +unset CDPATH + +for f in ${1+"$@"}; do + MS_COMMAND="$MS_COMMAND \\\\ + \\\"$f\\\"" +done + +# For Solaris systems +if test -d /usr/xpg4/bin; then + PATH=/usr/xpg4/bin:$PATH + export PATH +fi + +# Procedures + +MS_Usage() +{ + echo "Usage: $0 [args] archive_dir file_name label startup_script [script_args]" + echo "args can be one or more of the following :" + echo " --version | -v : Print out Makeself version number and exit" + echo " --help | -h : Print out this help message" + echo " --tar-quietly : Suppress verbose output from the tar command" + echo " --quiet | -q : Do not print any messages other than errors." + echo " --gzip : Compress using gzip (default if detected)" + echo " --pigz : Compress with pigz" + echo " --zstd : Compress with zstd" + echo " --bzip2 : Compress using bzip2 instead of gzip" + echo " --pbzip2 : Compress using pbzip2 instead of gzip" + echo " --xz : Compress using xz instead of gzip" + echo " --lzo : Compress using lzop instead of gzip" + echo " --lz4 : Compress using lz4 instead of gzip" + echo " --compress : Compress using the UNIX 'compress' command" + echo " --complevel lvl : Compression level for gzip pigz zstd xz lzo lz4 bzip2 and pbzip2 (default 9)" + echo " --threads thds : Number of threads to be used by compressors that support parallelization." + echo " Omit to use compressor's default. Most useful (and required) for opting" + echo " into xz's threading, usually with '--threads=0' for all available cores." + echo " pbzip2 and pigz are parallel by default, and setting this value allows" + echo " limiting the number of threads they use." 
+ echo " --base64 : Instead of compressing, encode the data using base64" + echo " --gpg-encrypt : Instead of compressing, encrypt the data using GPG" + echo " --gpg-asymmetric-encrypt-sign" + echo " : Instead of compressing, asymmetrically encrypt and sign the data using GPG" + echo " --gpg-extra opt : Append more options to the gpg command line" + echo " --ssl-encrypt : Instead of compressing, encrypt the data using OpenSSL" + echo " --ssl-passwd pass : Use the given password to encrypt the data using OpenSSL" + echo " --ssl-pass-src src : Use the given src as the source of password to encrypt the data" + echo " using OpenSSL. See \"PASS PHRASE ARGUMENTS\" in man openssl." + echo " If this option is not supplied, the user will be asked to enter" + echo " encryption password on the current terminal." + echo " --ssl-no-md : Do not use \"-md\" option not supported by older OpenSSL." + echo " --nochown : Do not give the target folder to the current user (default)" + echo " --chown : Give the target folder to the current user recursively" + echo " --nocomp : Do not compress the data" + echo " --notemp : The archive will create archive_dir in the" + echo " current directory and uncompress in ./archive_dir" + echo " --needroot : Check that the root user is extracting the archive before proceeding" + echo " --copy : Upon extraction, the archive will first copy itself to" + echo " a temporary directory" + echo " --append : Append more files to an existing Makeself archive" + echo " The label and startup scripts will then be ignored" + echo " --target dir : Extract directly to a target directory" + echo " directory path can be either absolute or relative" + echo " --nooverwrite : Do not extract the archive if the specified target directory exists" + echo " --current : Files will be extracted to the current directory" + echo " Both --current and --target imply --notemp" + echo " --tar-format opt : Specify a tar archive format (default is ustar)" + echo " --tar-extra opt : Append more options to the tar command line" + echo " --untar-extra opt : Append more options to the during the extraction of the tar archive" + echo " --nomd5 : Don't calculate an MD5 for archive" + echo " --nocrc : Don't calculate a CRC for archive" + echo " --sha256 : Compute a SHA256 checksum for the archive" + echo " --header file : Specify location of the header script" + echo " --cleanup file : Specify a cleanup script that executes on interrupt and when finished successfully." + echo " --follow : Follow the symlinks in the archive" + echo " --noprogress : Do not show the progress during the decompression" + echo " --nox11 : Disable automatic spawn of a xterm" + echo " --nowait : Do not wait for user input after executing embedded" + echo " program from an xterm" + echo " --sign passphrase : Signature private key to sign the package with" + echo " --lsm file : LSM file describing the package" + echo " --license file : Append a license file" + echo " --help-header file : Add a header to the archive's --help output" + echo " --packaging-date date" + echo " : Use provided string as the packaging date" + echo " instead of the current date." + echo + echo " --keep-umask : Keep the umask set to shell default, rather than overriding when executing self-extracting archive." + echo " --export-conf : Export configuration variables to startup_script" + echo + echo "Do not forget to give a fully qualified startup script name" + echo "(i.e. with a ./ prefix if inside the archive)." 
+ exit 1 +} + +# Default settings +if type gzip >/dev/null 2>&1; then + COMPRESS=gzip +elif type compress >/dev/null 2>&1; then + COMPRESS=compress +else + echo "ERROR: missing commands: gzip, compress" >&2 + MS_Usage +fi +ENCRYPT=n +PASSWD="" +PASSWD_SRC="" +OPENSSL_NO_MD=n +COMPRESS_LEVEL=9 +DEFAULT_THREADS=123456 # Sentinel value +THREADS=$DEFAULT_THREADS +KEEP=n +CURRENT=n +NOX11=n +NOWAIT=n +APPEND=n +TAR_QUIETLY=n +KEEP_UMASK=n +QUIET=n +NOPROGRESS=n +COPY=none +NEED_ROOT=n +TAR_ARGS=rvf +TAR_FORMAT=ustar +TAR_EXTRA="" +GPG_EXTRA="" +DU_ARGS=-ks +HEADER=`dirname "$0"`/makeself-header.sh +SIGNATURE="" +TARGETDIR="" +NOOVERWRITE=n +DATE=`LC_ALL=C date` +EXPORT_CONF=n +SHA256=n +OWNERSHIP=n +SIGN=n +GPG_PASSPHRASE="" + +# LSM file stuff +LSM_CMD="echo No LSM. >> \"\$archname\"" + +while true +do + case "$1" in + --version | -v) + echo Makeself version $MS_VERSION + exit 0 + ;; + --pbzip2) + COMPRESS=pbzip2 + shift + ;; + --bzip2) + COMPRESS=bzip2 + shift + ;; + --gzip) + COMPRESS=gzip + shift + ;; + --pigz) + COMPRESS=pigz + shift + ;; + --zstd) + COMPRESS=zstd + shift + ;; + --xz) + COMPRESS=xz + shift + ;; + --lzo) + COMPRESS=lzo + shift + ;; + --lz4) + COMPRESS=lz4 + shift + ;; + --compress) + COMPRESS=compress + shift + ;; + --base64) + COMPRESS=base64 + shift + ;; + --gpg-encrypt) + COMPRESS=gpg + shift + ;; + --gpg-asymmetric-encrypt-sign) + COMPRESS=gpg-asymmetric + shift + ;; + --gpg-extra) + GPG_EXTRA="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --ssl-encrypt) + ENCRYPT=openssl + shift + ;; + --ssl-passwd) + PASSWD=$2 + shift 2 || { MS_Usage; exit 1; } + ;; + --ssl-pass-src) + PASSWD_SRC=$2 + shift 2 || { MS_Usage; exit 1; } + ;; + --ssl-no-md) + OPENSSL_NO_MD=y + shift + ;; + --nocomp) + COMPRESS=none + shift + ;; + --complevel) + COMPRESS_LEVEL="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --threads) + THREADS="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --nochown) + OWNERSHIP=n + shift + ;; + --chown) + OWNERSHIP=y + shift + ;; + --notemp) + KEEP=y + shift + ;; + --copy) + COPY=copy + shift + ;; + --current) + CURRENT=y + KEEP=y + shift + ;; + --tar-format) + TAR_FORMAT="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --tar-extra) + TAR_EXTRA="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --untar-extra) + UNTAR_EXTRA="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --target) + TARGETDIR="$2" + KEEP=y + shift 2 || { MS_Usage; exit 1; } + ;; + --sign) + SIGN=y + GPG_PASSPHRASE="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --nooverwrite) + NOOVERWRITE=y + shift + ;; + --needroot) + NEED_ROOT=y + shift + ;; + --header) + HEADER="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --cleanup) + CLEANUP_SCRIPT="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --license) + # We need to escape all characters having a special meaning in double quotes + LICENSE=$(sed 's/\\/\\\\/g; s/"/\\\"/g; s/`/\\\`/g; s/\$/\\\$/g' "$2") + shift 2 || { MS_Usage; exit 1; } + ;; + --follow) + TAR_ARGS=rvhf + DU_ARGS=-ksL + shift + ;; + --noprogress) + NOPROGRESS=y + shift + ;; + --nox11) + NOX11=y + shift + ;; + --nowait) + NOWAIT=y + shift + ;; + --nomd5) + NOMD5=y + shift + ;; + --sha256) + SHA256=y + shift + ;; + --nocrc) + NOCRC=y + shift + ;; + --append) + APPEND=y + shift + ;; + --lsm) + LSM_CMD="cat \"$2\" >> \"\$archname\"" + shift 2 || { MS_Usage; exit 1; } + ;; + --packaging-date) + DATE="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --help-header) + HELPHEADER=`sed -e "s/'/'\\\\\''/g" $2` + shift 2 || { MS_Usage; exit 1; } + [ -n "$HELPHEADER" ] && HELPHEADER="$HELPHEADER +" + ;; + --tar-quietly) + 
TAR_QUIETLY=y + shift + ;; + --keep-umask) + KEEP_UMASK=y + shift + ;; + --export-conf) + EXPORT_CONF=y + shift + ;; + -q | --quiet) + QUIET=y + shift + ;; + -h | --help) + MS_Usage + ;; + -*) + echo Unrecognized flag : "$1" + MS_Usage + ;; + *) + break + ;; + esac +done + +if test $# -lt 1; then + MS_Usage +else + if test -d "$1"; then + archdir="$1" + else + echo "Directory $1 does not exist." >&2 + exit 1 + fi +fi +archname="$2" + +if test "$QUIET" = "y" || test "$TAR_QUIETLY" = "y"; then + if test "$TAR_ARGS" = "rvf"; then + TAR_ARGS="rf" + elif test "$TAR_ARGS" = "rvhf"; then + TAR_ARGS="rhf" + fi +fi + +if test "$APPEND" = y; then + if test $# -lt 2; then + MS_Usage + fi + + # Gather the info from the original archive + OLDENV=`sh "$archname" --dumpconf` + if test $? -ne 0; then + echo "Unable to update archive: $archname" >&2 + exit 1 + else + eval "$OLDENV" + OLDSKIP=`expr $SKIP + 1` + fi +else + if test "$KEEP" = n -a $# = 3; then + echo "ERROR: Making a temporary archive with no embedded command does not make sense!" >&2 + echo >&2 + MS_Usage + fi + # We don't want to create an absolute directory unless a target directory is defined + if test "$CURRENT" = y; then + archdirname="." + elif test x"$TARGETDIR" != x; then + archdirname="$TARGETDIR" + else + archdirname=`basename "$1"` + fi + + if test $# -lt 3; then + MS_Usage + fi + + LABEL="$3" + SCRIPT="$4" + test "x$SCRIPT" = x || shift 1 + shift 3 + SCRIPTARGS="$*" +fi + +if test "$KEEP" = n -a "$CURRENT" = y; then + echo "ERROR: It is A VERY DANGEROUS IDEA to try to combine --notemp and --current." >&2 + exit 1 +fi + +case $COMPRESS in +gzip) + GZIP_CMD="gzip -c$COMPRESS_LEVEL" + GUNZIP_CMD="gzip -cd" + ;; +pigz) + GZIP_CMD="pigz -$COMPRESS_LEVEL" + if test $THREADS -ne $DEFAULT_THREADS; then # Leave as the default if threads not indicated + GZIP_CMD="$GZIP_CMD --processes $THREADS" + fi + GUNZIP_CMD="gzip -cd" + ;; +zstd) + GZIP_CMD="zstd -$COMPRESS_LEVEL" + if test $THREADS -ne $DEFAULT_THREADS; then # Leave as the default if threads not indicated + GZIP_CMD="$GZIP_CMD --threads=$THREADS" + fi + GUNZIP_CMD="zstd -cd" + ;; +pbzip2) + GZIP_CMD="pbzip2 -c$COMPRESS_LEVEL" + if test $THREADS -ne $DEFAULT_THREADS; then # Leave as the default if threads not indicated + GZIP_CMD="$GZIP_CMD -p$THREADS" + fi + GUNZIP_CMD="bzip2 -d" + ;; +bzip2) + GZIP_CMD="bzip2 -$COMPRESS_LEVEL" + GUNZIP_CMD="bzip2 -d" + ;; +xz) + GZIP_CMD="xz -c$COMPRESS_LEVEL" + # Must opt-in by specifying a value since not all versions of xz support threads + if test $THREADS -ne $DEFAULT_THREADS; then + GZIP_CMD="$GZIP_CMD --threads=$THREADS" + fi + GUNZIP_CMD="xz -d" + ;; +lzo) + GZIP_CMD="lzop -c$COMPRESS_LEVEL" + GUNZIP_CMD="lzop -d" + ;; +lz4) + GZIP_CMD="lz4 -c$COMPRESS_LEVEL" + GUNZIP_CMD="lz4 -d" + ;; +base64) + GZIP_CMD="base64" + GUNZIP_CMD="base64 --decode -i -" + ;; +gpg) + GZIP_CMD="gpg $GPG_EXTRA -ac -z$COMPRESS_LEVEL" + GUNZIP_CMD="gpg -d" + ENCRYPT="gpg" + ;; +gpg-asymmetric) + GZIP_CMD="gpg $GPG_EXTRA -z$COMPRESS_LEVEL -es" + GUNZIP_CMD="gpg --yes -d" + ENCRYPT="gpg" + ;; +compress) + GZIP_CMD="compress -fc" + GUNZIP_CMD="(type compress >/dev/null 2>&1 && compress -fcd || gzip -cd)" + ;; +none) + GZIP_CMD="cat" + GUNZIP_CMD="cat" + ;; +esac + +if test x"$ENCRYPT" = x"openssl"; then + if test x"$APPEND" = x"y"; then + echo "Appending to existing archive is not compatible with OpenSSL encryption." 
>&2 + fi + + ENCRYPT_CMD="openssl enc -aes-256-cbc -salt" + DECRYPT_CMD="openssl enc -aes-256-cbc -d" + + if test x"$OPENSSL_NO_MD" != x"y"; then + ENCRYPT_CMD="$ENCRYPT_CMD -md sha256" + DECRYPT_CMD="$DECRYPT_CMD -md sha256" + fi + + if test -n "$PASSWD_SRC"; then + ENCRYPT_CMD="$ENCRYPT_CMD -pass $PASSWD_SRC" + elif test -n "$PASSWD"; then + ENCRYPT_CMD="$ENCRYPT_CMD -pass pass:$PASSWD" + fi +fi + +tmpfile="${TMPDIR:-/tmp}/mkself$$" + +if test -f "$HEADER"; then + oldarchname="$archname" + archname="$tmpfile" + # Generate a fake header to count its lines + SKIP=0 + . "$HEADER" + SKIP=`cat "$tmpfile" |wc -l` + # Get rid of any spaces + SKIP=`expr $SKIP` + rm -f "$tmpfile" + if test "$QUIET" = "n"; then + echo "Header is $SKIP lines long" >&2 + fi + archname="$oldarchname" +else + echo "Unable to open header file: $HEADER" >&2 + exit 1 +fi + +if test "$QUIET" = "n"; then + echo +fi + +if test "$APPEND" = n; then + if test -f "$archname"; then + echo "WARNING: Overwriting existing file: $archname" >&2 + fi +fi + +USIZE=`du $DU_ARGS "$archdir" | awk '{print $1}'` + +if test "." = "$archdirname"; then + if test "$KEEP" = n; then + archdirname="makeself-$$-`date +%Y%m%d%H%M%S`" + fi +fi + +test -d "$archdir" || { echo "Error: $archdir does not exist."; rm -f "$tmpfile"; exit 1; } +if test "$QUIET" = "n"; then + echo "About to compress $USIZE KB of data..." + echo "Adding files to archive named \"$archname\"..." +fi + +# See if we have GNU tar +TAR=`exec <&- 2>&-; which gtar || command -v gtar || type gtar` +test -x "$TAR" || TAR=tar + +tmparch="${TMPDIR:-/tmp}/mkself$$.tar" +( + if test "$APPEND" = "y"; then + tail -n "+$OLDSKIP" "$archname" | eval "$GUNZIP_CMD" > "$tmparch" + fi + cd "$archdir" + # "Determining if a directory is empty" + # https://www.etalabs.net/sh_tricks.html + find . \ + \( \ + ! -type d \ + -o \ + \( -links 2 -exec sh -c ' + is_empty () ( + cd "$1" + set -- .[!.]* ; test -f "$1" && return 1 + set -- ..?* ; test -f "$1" && return 1 + set -- * ; test -f "$1" && return 1 + return 0 + ) + is_empty "$0"' {} \; \ + \) \ + \) -print \ + | LC_ALL=C sort \ + | sed 's/./\\&/g' \ + | xargs $TAR $TAR_EXTRA --format $TAR_FORMAT -$TAR_ARGS "$tmparch" +) || { + echo "ERROR: failed to create temporary archive: $tmparch" + rm -f "$tmparch" "$tmpfile" + exit 1 +} + +USIZE=`du $DU_ARGS "$tmparch" | awk '{print $1}'` + +eval "$GZIP_CMD" <"$tmparch" >"$tmpfile" || { + echo "ERROR: failed to create temporary file: $tmpfile" + rm -f "$tmparch" "$tmpfile" + exit 1 +} +rm -f "$tmparch" + +if test x"$ENCRYPT" = x"openssl"; then + echo "About to encrypt archive \"$archname\"..." 
+ { eval "$ENCRYPT_CMD -in $tmpfile -out ${tmpfile}.enc" && mv -f ${tmpfile}.enc $tmpfile; } || \ + { echo Aborting: could not encrypt temporary file: "$tmpfile".; rm -f "$tmpfile"; exit 1; } +fi + +fsize=`cat "$tmpfile" | wc -c | tr -d " "` + +# Compute the checksums + +shasum=0000000000000000000000000000000000000000000000000000000000000000 +md5sum=00000000000000000000000000000000 +crcsum=0000000000 + +if test "$NOCRC" = y; then + if test "$QUIET" = "n"; then + echo "skipping crc at user request" + fi +else + crcsum=`CMD_ENV=xpg4 cksum < "$tmpfile" | sed -e 's/ /Z/' -e 's/ /Z/' | cut -dZ -f1` + if test "$QUIET" = "n"; then + echo "CRC: $crcsum" + fi +fi + +if test "$SHA256" = y; then + SHA_PATH=`exec <&- 2>&-; which shasum || command -v shasum || type shasum` + if test -x "$SHA_PATH"; then + shasum=`eval "$SHA_PATH -a 256" < "$tmpfile" | cut -b-64` + else + SHA_PATH=`exec <&- 2>&-; which sha256sum || command -v sha256sum || type sha256sum` + shasum=`eval "$SHA_PATH" < "$tmpfile" | cut -b-64` + fi + if test "$QUIET" = "n"; then + if test -x "$SHA_PATH"; then + echo "SHA256: $shasum" + else + echo "SHA256: none, SHA command not found" + fi + fi +fi +if test "$NOMD5" = y; then + if test "$QUIET" = "n"; then + echo "Skipping md5sum at user request" + fi +else + # Try to locate a MD5 binary + OLD_PATH=$PATH + PATH=${GUESS_MD5_PATH:-"$OLD_PATH:/bin:/usr/bin:/sbin:/usr/local/ssl/bin:/usr/local/bin:/opt/openssl/bin"} + MD5_ARG="" + MD5_PATH=`exec <&- 2>&-; which md5sum || command -v md5sum || type md5sum` + test -x "$MD5_PATH" || MD5_PATH=`exec <&- 2>&-; which md5 || command -v md5 || type md5` + test -x "$MD5_PATH" || MD5_PATH=`exec <&- 2>&-; which digest || command -v digest || type digest` + PATH=$OLD_PATH + if test -x "$MD5_PATH"; then + if test `basename ${MD5_PATH}`x = digestx; then + MD5_ARG="-a md5" + fi + md5sum=`eval "$MD5_PATH $MD5_ARG" < "$tmpfile" | cut -b-32` + if test "$QUIET" = "n"; then + echo "MD5: $md5sum" + fi + else + if test "$QUIET" = "n"; then + echo "MD5: none, MD5 command not found" + fi + fi +fi +if test "$SIGN" = y; then + GPG_PATH=`exec <&- 2>&-; which gpg || command -v gpg || type gpg` + if test -x "$GPG_PATH"; then + SIGNATURE=`$GPG_PATH --pinentry-mode=loopback --batch --yes --passphrase "$GPG_PASSPHRASE" --output - --detach-sig $tmpfile | base64 | tr -d \\\\n` + if test "$QUIET" = "n"; then + echo "Signature: $SIGNATURE" + fi + else + echo "Missing gpg command" >&2 + fi +fi + +totalsize=0 +for size in $fsize; +do + totalsize=`expr $totalsize + $size` +done + +if test "$APPEND" = y; then + mv "$archname" "$archname".bak || exit + + # Prepare entry for new archive + filesizes="$fsize" + CRCsum="$crcsum" + MD5sum="$md5sum" + SHAsum="$shasum" + Signature="$SIGNATURE" + # Generate the header + . "$HEADER" + # Append the new data + cat "$tmpfile" >> "$archname" + + chmod +x "$archname" + rm -f "$archname".bak + if test "$QUIET" = "n"; then + echo "Self-extractable archive \"$archname\" successfully updated." + fi +else + filesizes="$fsize" + CRCsum="$crcsum" + MD5sum="$md5sum" + SHAsum="$shasum" + Signature="$SIGNATURE" + + # Generate the header + . "$HEADER" + + # Append the compressed tar data after the stub + if test "$QUIET" = "n"; then + echo + fi + cat "$tmpfile" >> "$archname" + chmod +x "$archname" + if test "$QUIET" = "n"; then + echo Self-extractable archive \"$archname\" successfully created. 
+ fi +fi +rm -f "$tmpfile" diff --git a/csrc/deepep/ops2/cmake/util/makeself/run-tests.sh b/csrc/deepep/ops2/cmake/util/makeself/run-tests.sh new file mode 100755 index 00000000..31ee1651 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/makeself/run-tests.sh @@ -0,0 +1,8 @@ +#!/bin/sh +# Run every available test - Bash needed +cd test +for test in *test; +do + echo "Running test $test ..." + bash $test || { echo "*** ERROR: Test '$test' failed!"; exit 1; } +done diff --git a/csrc/deepep/ops2/cmake/util/merge_aicpu_info_json.sh b/csrc/deepep/ops2/cmake/util/merge_aicpu_info_json.sh new file mode 100755 index 00000000..970a44bf --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/merge_aicpu_info_json.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +project_path=$1 +build_path=$2 +vendor_name=customize +echo $@ +if [[ ! -d "$project_path" ]]; then + echo "[ERROR] No project path is provided" + exit 1 +fi + +if [[ ! -d "$build_path" ]]; then + echo "[ERROR] No build path is provided" + exit 1 +fi + +if [[ ! -d "$ASCEND_OPP_PATH" ]]; then + echo "[ERROR] No opp install path is provided" + exit 1 +fi +custom_exist_info_json=$ASCEND_OPP_PATH/vendors/$vendor_name/op_impl/cpu/config/cust_aicpu_kernel.json +custom_new_info_json=$build_path/makepkg/packages/vendors/$vendor_name/op_impl/cpu/config/cust_aicpu_kernel.json +temp_info_json=$build_path/makepkg/packages/vendors/$vendor_name/op_impl/cpu/config/temp_cust_aicpu_kernel.json + +if [[ -f "$custom_exist_info_json" ]] && [[ -f "$custom_new_info_json" ]]; then + cp -f $custom_exist_info_json $temp_info_json + chmod +w $temp_info_json + python3 ${project_path}/cmake/util/insert_op_info.py ${custom_new_info_json} ${temp_info_json} + cp -f $temp_info_json $custom_new_info_json + rm -f $temp_info_json +fi diff --git a/csrc/deepep/ops2/cmake/util/opdesc_parser.py b/csrc/deepep/ops2/cmake/util/opdesc_parser.py new file mode 100755 index 00000000..7b789567 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/opdesc_parser.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import os +import sys + +OP_ALL = "__ALLOP__" +SOC_ALL = "__ALLSOC__" +SOC_TO_SHORT_SOC_MAP = { + "ascend910a": "ascend910", + "ascend910proa": "ascend910", + "ascend910b": "ascend910", + "ascend910prob": "ascend910", + "ascend910premiuma": "ascend910", + "ascend910b1": "ascend910b", + "ascend910b2": "ascend910b", + "ascend910b2c": "ascend910b", + "ascend910b3": "ascend910b", + "ascend910b4": "ascend910b", + "ascend910b4-1": "ascend910b", + "ascend910_9391": "ascend910_93", + "ascend910_9381": "ascend910_93", + "ascend910_9372": "ascend910_93", + "ascend910_9392": "ascend910_93", + "ascend910_9382": "ascend910_93", + "ascend910_9361": "ascend910_93", + "ascend310p1": "ascend310p", + "ascend310p3": "ascend310p", + "ascend310p3vir01": "ascend310p", + "ascend310p3vir02": "ascend310p", + "ascend310p3vir04": "ascend310p", + "ascend310p3vir08": "ascend310p", + "ascend310b1": "ascend310b", + "bs9sx1aa": "bs9sx1a", + "ascend610lite": "ascend610lite", +} +CONFLICT_KEYWORDS = { + "and", + "as", + "assert", + "break", + "class", + "continue", + "def", + "del", + "elif", + "else", + "except", + "finally", + "for", + "from", + "global", + "if", + "import", + "in", + "is", + "lambda", + "not", + "or", + "pass", + "raise", + "return", + "try", + "while", + "with", + "yield", + "False", + "None", + "True", + "nonlocal", + "arg", + "__inputs__", + "__outputs__", + "options", + "bisheng", + "bisheng_path", + "tikcpp_path", + "impl_mode", + "custom_compile_options", + "custom_all_compile_options", + 
"soc_version", + "soc_short", + "custom_compile_options_soc", + "custom_all_compile_options_soc", + "origin_func_name", + "ascendc_src_dir_ex", + "ascendc_src_dir", + "ascendc_src_file", + "src", + "op_type", + "code_channel", + "op_info", + "compile_op", + "get_code_channel", + "result", + "__attrs__", + "isinstance", + "attr", + "get_current_build_config", + "_build_args", + "get_dtype_fmt_options", + "shutil", + "os", + "get_kernel_source", +} + + +class OpDesc: + def __init__(self: any, op_type: str): + self.op_type = op_type + self.attr_list = [] + self.attr_val = {} + self.input_name = [] + self.input_ori_name = [] + self.input_type = [] + self.input_dtype = [] + self.input_dtype_for_bin_list = [] + self.input_dtype_for_bin = {} + self.input_fmt = [] + self.input_fmt_for_bin_list = [] + self.input_fmt_for_bin = {} + self.input_virt = {} + self.output_name = [] + self.output_ori_name = [] + self.output_type = [] + self.output_dtype = [] + self.output_dtype_for_bin_list = [] + self.output_dtype_for_bin = {} + self.output_fmt = [] + self.output_fmt_for_bin_list = [] + self.output_fmt_for_bin = {} + self.output_init_value = [] + self.output_shape_depend_on_compute = [] + self.op_fmt_sel = False + self.op_chk_support = False + self.op_intf = "" + self.kern_name = "" + self.op_file = "" + self.op_replay_flag = False + self.op_replay_batch = False + self.input_idx = -1 + self.output_idx = -1 + self.max_block_dim = 32 + self.max_shape_size = 268435456 + self.dynamic_shape = False + self.op_range_limit = "" + self.custom_compile_options = {} + self.custom_all_compile_options = {} + self.param_type_dynamic = False + self.mc2_ctx = [] + self.bin_cprs_list = [] + self.bin_cprs_head = [] + self.bin_save_list = [] + + @staticmethod + def _parse_digit(conf: str) -> int: + return int(conf.split("=")[1]) + + @staticmethod + def _parse_flag(conf: str) -> bool: + if "true" == conf.split("=")[1]: + return True + return False + + @staticmethod + def _parse_str(conf: str) -> str: + return conf.split("=")[1] + + @staticmethod + def _parse_list(conf: str) -> list: + return conf.split("=")[1].split(",") + + def parse_input(self: any, conf: str): + if conf.startswith("input{}.name".format(int(self.input_idx) + 1)): + self.input_idx += 1 + self.input_ori_name.append(self._parse_str(conf)) + self.input_name.append(self.input_ori_name[-1] + "_in__") + elif conf.startswith("input{}.paramType".format(int(self.input_idx))): + param_type = self._parse_str(conf) + self.input_type.append(param_type) + if param_type == "dynamic": + self.param_type_dynamic = True + elif conf.startswith("input{}.dtype".format(int(self.input_idx))): + self.input_dtype.append(self._parse_str(conf)) + elif conf.startswith("input{}.for_bin_dtype".format(int(self.input_idx))): + self.input_dtype_for_bin.update({self.input_idx: self._parse_str(conf)}) + elif conf.startswith("input{}.format".format(int(self.input_idx))): + self.input_fmt.append(self._parse_str(conf)) + elif conf.startswith("input{}.for_bin_format".format(int(self.input_idx))): + self.input_fmt_for_bin.update({self.input_idx: self._parse_str(conf)}) + elif conf.startswith("input{}.virtual".format(int(self.input_idx))): + self.input_virt[self.input_idx] = self._parse_str(conf) + elif conf.startswith("input{}.initValue".format(int(self.input_idx))): + raise Exception( + f"[ERROR]: Op: {{'{self.op_type}'}} input {self.input_ori_name[int(self.input_idx)]}\ + has InitValue, which is not support!" 
+ ) + else: + return + + def parse_output(self: any, conf: str): + if conf.startswith("output{}.name".format(int(self.output_idx) + 1)): + self.output_idx += 1 + self.output_ori_name.append(self._parse_str(conf)) + self.output_name.append(self.output_ori_name[-1] + "_out_") + self.output_init_value.append(None) + elif conf.startswith("output{}.paramType".format(int(self.output_idx))): + param_type = self._parse_str(conf) + self.output_type.append(param_type) + if param_type == "dynamic": + self.param_type_dynamic = True + elif conf.startswith("output{}.dtype".format(int(self.output_idx))): + self.output_dtype.append(self._parse_str(conf)) + elif conf.startswith("output{}.for_bin_dtype".format(int(self.output_idx))): + self.output_dtype_for_bin.update({self.output_idx: self._parse_str(conf)}) + elif conf.startswith("output{}.format".format(int(self.output_idx))): + self.output_fmt.append(self._parse_str(conf)) + elif conf.startswith("output{}.for_bin_format".format(int(self.output_idx))): + self.output_fmt_for_bin.update({self.output_idx: self._parse_str(conf)}) + elif conf.startswith("output{}.initValue".format(int(self.output_idx))): + self.output_init_value[int(self.output_idx)] = self._parse_str(conf) + elif conf.startswith( + "output{}.outputShapeDependOnCompute=true".format(int(self.output_idx)) + ): + self.output_shape_depend_on_compute.append(int(self.output_idx)) + else: + return + + def parse_op_format(self: any, conf: str): + self.op_fmt_sel = self._parse_flag(conf) + + def parse_check_support(self: any, conf: str): + self.op_chk_support = self._parse_flag(conf) + + def parse_range_limit(self: any, conf: str): + self.op_range_limit = self._parse_str(conf) + + def parse_kern_name(self: any, conf: str): + self.kern_name = self._parse_str(conf) + + def parse_op_intf(self: any, conf: str): + self.op_intf = self._parse_str(conf) + + def parse_op_file(self: any, conf: str): + self.op_file = self._parse_str(conf) + + def parse_dynamic_shape(self: any, conf: str): + self.dynamic_shape = self._parse_flag(conf) + + def parse_attr_list(self: any, conf: str): + self.attr_list = self._parse_list(conf) + intersection_element = set(self.attr_list) & CONFLICT_KEYWORDS + if intersection_element: + raise Exception( + f"[ERROR]: The attribute name: {intersection_element} in op: {{'{self.op_type}'}} \ +conflicts with the built-in variable name. Use a complex name or prefix the operator name." 
+ ) + + def parse_mc2_ctx(self: any, conf: str): + self.mc2_ctx = self._parse_list(conf) + + @staticmethod + def _camel_to_snake(camel_case_str: str): + snake_case_str = "" + for i, c in enumerate(camel_case_str): + if i == 0: + snake_case_str += c.lower() + elif c.isupper(): + snake_case_str += "_" + c.lower() + else: + snake_case_str += c + return snake_case_str + + def parse_attr_val(self: any, conf: str): + for attr in self.attr_list: + if self.attr_val.get(attr) is None: + self.attr_val[attr] = {} + if conf.startswith("attr_{}.type".format(attr)): + self.attr_val.get(attr)["type"] = self._camel_to_snake( + self._parse_str(conf) + ) + elif conf.startswith("attr_{}.paramType".format(attr)): + self.attr_val.get(attr)["paramType"] = self._parse_str(conf) + elif conf.startswith("attr_{}.defaultValue".format(attr)): + self.attr_val.get(attr)["defaultValue"] = self._parse_str(conf) + + def parse_replay_val(self: any, batch_list: list, iterator_list: list): + if self.op_type in batch_list: + self.op_replay_flag = True + self.op_replay_batch = True + elif self.op_type in iterator_list: + self.op_replay_flag = True + self.op_replay_batch = False + + +def _is_op_type_in_opdesc(op_descs: list, op_type: str): + for op in op_descs: + if op_type == op.op_type: + return True + return False + + +def _set_all_options_to_opdescs(op_descs, soc_ver_compile_options): + for op in op_descs: + op.custom_all_compile_options = soc_ver_compile_options + + +def _set_options_to_opdesc(op_descs, op_type, soc_ver_compile_options): + for op in op_descs: + if op.op_type != op_type: + continue + op.custom_compile_options.update(soc_ver_compile_options) + + +def _trans_soc_ver_to_short(soc_ver: str): + low_soc_ver = soc_ver.lower() + if low_soc_ver not in SOC_TO_SHORT_SOC_MAP: + print( + f"WARNING: caution: {soc_ver} will trans into ascend910, if not your intention," + f"use ascend910b1~4 instead" + ) + return SOC_TO_SHORT_SOC_MAP[low_soc_ver] + + +def _get_op_custom_options(op_descs: list, auto_gen_dir: str): + if auto_gen_dir is None: + return {} + file = os.path.join(auto_gen_dir, "custom_compile_options.ini") + if not os.path.exists(file): + print(f"WARNING: cannot find {auto_gen_dir}/custom_compile_options.ini") + return {} + with open(file, "r") as fd: + lines = fd.readlines() + for line in lines: + param_list = str.split(line.rstrip("\n"), ",") + if len(param_list) != 3: + raise Exception( + f"ERROR: custom compile option {param_list} len is not 3" + ) + op_type = param_list[0] + if op_type.upper() == "ALL": + op_type = OP_ALL + if op_type != OP_ALL and _is_op_type_in_opdesc(op_descs, op_type) == False: + continue + soc_ver_compile_options = {} + soc_ver = param_list[1] + options_str = param_list[2] + options = str.split(options_str, ";") + if soc_ver == "": + soc_ver_compile_options[SOC_ALL] = options + else: + soc_ver_list = str.split(soc_ver, ";") + for ver in soc_ver_list: + short_ver = _trans_soc_ver_to_short(ver) + soc_ver_compile_options[short_ver] = options + if op_type == OP_ALL: + _set_all_options_to_opdescs(op_descs, soc_ver_compile_options) + else: + _set_options_to_opdesc(op_descs, op_type, soc_ver_compile_options) + + +def get_op_desc( + file: str, + batch_list: list, + iterator_list: list, + builder: any, + op_type: list, + auto_gen_dir: str = None, +) -> list: + op_descs = [] + op_match = False + with open(file, "r") as fd: + lines = fd.readlines() + for line in lines: + line = line.strip() + if line.startswith("["): + name = line[1:-1] + if op_type is None or name in op_type: + op_match = True + 
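For context on the custom_compile_options.ini format consumed by _get_op_custom_options above: each non-empty line holds exactly three comma-separated fields — an op type (or ALL), a semicolon-separated SoC list (empty meaning every SoC), and a semicolon-separated list of compile options — and SoC names are shortened through SOC_TO_SHORT_SOC_MAP before use. A minimal stand-alone sketch of that per-line split, using a made-up MoeDispatch op type and placeholder flags:

# Stand-alone sketch of one custom_compile_options.ini line; "MoeDispatch" and the
# flags are placeholders, and the real parsing (including the SOC_TO_SHORT_SOC_MAP
# lookup that maps e.g. ascend910b2/ascend910b3 to ascend910b) lives in
# _get_op_custom_options above.
line = "MoeDispatch,ascend910b2;ascend910b3,-DDEBUG;-g"
op_type, soc_vers, options_str = line.rstrip("\n").split(",")
options = options_str.split(";")
if soc_vers == "":
    per_soc = {"__ALLSOC__": options}   # the SOC_ALL bucket
else:
    per_soc = {ver: options for ver in soc_vers.split(";")}
print(op_type, per_soc)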
op_desc = builder(name) + op_desc.parse_replay_val(batch_list, iterator_list) + op_descs.append(op_desc) + else: + op_match = False + if op_type is not None and len(op_descs) == len(op_type): + break + continue + if not op_match: + continue + if line.startswith("input"): + op_desc.parse_input(line) + elif line.startswith("output"): + op_desc.parse_output(line) + elif line.startswith("dynamicFormat.flag"): + op_desc.parse_op_format(line) + elif line.startswith("needCheckSupport.flag"): + op_desc.parse_check_support(line) + elif line.startswith("rangeLimit.value"): + op_desc.parse_range_limit(line) + elif line.startswith("opInterface.value"): + op_desc.parse_op_intf(line) + elif line.startswith("kernel.name"): + op_desc.parse_kern_name(line) + elif line.startswith("opFile.value"): + op_desc.parse_op_file(line) + elif line.startswith("dynamicShapeSupport.flag"): + op_desc.parse_dynamic_shape(line) + elif line.startswith("mc2.ctx"): + op_desc.parse_mc2_ctx(line) + elif line.startswith("attr.list"): + op_desc.parse_attr_list(line) + elif line.startswith("attr_"): + op_desc.parse_attr_val(line) + _get_op_custom_options(op_descs, auto_gen_dir) + return op_descs diff --git a/csrc/deepep/ops2/cmake/util/parse_ini_to_json.py b/csrc/deepep/ops2/cmake/util/parse_ini_to_json.py new file mode 100755 index 00000000..928acae6 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/parse_ini_to_json.py @@ -0,0 +1,449 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +""" +parser ini to json +""" + +import json +import os +import stat +import sys + +ATTR_TYPE_LIST = [ + "int", + "float", + "bool", + "str", + "listInt", + "listFloat", + "listBool", + "listStr", + "listListInt", + "type", + "listType", + "tensor", + "listTensor", +] +ATTR_PARAMTYPE_LIST = ["optional", "required"] +BOOL_FLAG_KEY = [ + "dynamicFormat", + "dynamicShapeSupport", + "dynamicRankSupport", + "precision_reduce", + "heavyOp", + "needCheckSupport", + "enableVectorCore", +] +BOOL_LIST = ["true", "false"] +DTYPE_LIST = [ + "float16", + "float", + "float32", + "int8", + "int16", + "int32", + "uint8", + "uint16", + "uint32", + "bool", + "int64", + "uint64", + "qint8", + "qint16", + "qint32", + "quint8", + "quint16", + "double", + "complex64", + "complex128", + "string", + "resource", + "dual", + "dual_sub_int8", + "dual_sub_uint8", + "string_ref", + "int4", + "bfloat16", + "uint1", +] +FORMAT_LIST = [ + "NCHW", + "NHWC", + "ND", + "NC1HWC0", + "FRACTAL_Z", + "NC1C0HWPAD", + "NHWC1C0", + "FSR_NCHW", + "FRACTAL_DECONV", + "C1HWNC0", + "FRACTAL_DECONV_TRANSPOSE", + "FRACTAL_DECONV_SP_STRIDE_TRANS", + "NC1HWC0_C04", + "FRACTAL_Z_C04", + "CHWN", + "FRACTAL_DECONV_SP_STRIDE8_TRANS", + "HWCN", + "NC1KHKWHWC0", + "BN_WEIGHT", + "FILTER_HWCK", + "HASHTABLE_LOOKUP_LOOKUPS", + "HASHTABLE_LOOKUP_KEYS", + "HASHTABLE_LOOKUP_VALUE", + "HASHTABLE_LOOKUP_OUTPUT", + "HASHTABLE_LOOKUP_HITS", + "C1HWNCoC0", + "MD", + "NDHWC", + "FRACTAL_ZZ", + "FRACTAL_NZ", + "NCDHW", + "DHWCN", + "NDC1HWC0", + "FRACTAL_Z_3D", + "CN", + "NC", + "DHWNC", + "FRACTAL_Z_3D_TRANSPOSE", + "FRACTAL_ZN_LSTM", + "FRACTAL_ZN_RNN", + "FRACTAL_Z_G", + "NULL", +] + + +def parse_ini_files(ini_files): + """ + parse ini files to json + Parameters: + ---------------- + ini_files:input file list + return:ops_info + ---------------- + """ + tbe_ops_info = {} + for ini_file in ini_files: + check_file_size(ini_file) + parse_ini_to_obj(ini_file, tbe_ops_info) + return tbe_ops_info + + +def check_file_size(input_file): + try: + file_size = os.path.getsize(input_file) + except OSError as os_error: + 
print('[ERROR] Failed to open "%s". %s' % (input_file, str(os_error))) + raise OSError from os_error + if file_size > 10 * 1024 * 1024: + print( + "[WARN] The size of %s exceeds 10MB, it may take more time to run, please wait." + % input_file + ) + + +def parse_ini_to_obj(ini_file, tbe_ops_info): + """ + parse ini file to json obj + Parameters: + ---------------- + ini_file:ini file path + tbe_ops_info:ops_info + ---------------- + """ + with open(ini_file) as ini_file: + lines = ini_file.readlines() + op_dict = {} + op_name = "" + find_op_type = False + for line in lines: + line = line.rstrip() + if line == "": + continue + if line.startswith("["): + if line.endswith("]"): + op_name = line[1:-1] + op_dict = {} + tbe_ops_info[op_name] = op_dict + find_op_type = True + elif "=" in line: + key1 = line[: line.index("=")] + key2 = line[line.index("=") + 1 :] + key1_0, key1_1 = key1.split(".") + if key1_0 not in op_dict: + op_dict[key1_0] = {} + if key1_1 in op_dict.get(key1_0): + raise RuntimeError( + "Op:" + op_name + " " + key1_0 + " " + key1_1 + " is repeated!" + ) + dic_key = op_dict.get(key1_0) + dic_key[key1_1] = key2 + else: + continue + if not find_op_type: + raise RuntimeError("Not find OpType in .ini file.") + + +def check_output_exist(op_dict, is_valid): + """ + Function Description: + Check output is exist + Parameter: op_dict + Parameter: is_valid + """ + if "output0" in op_dict: + output0_dict = op_dict.get("output0") + if output0_dict.get("name", None) is None: + is_valid = False + print("output0.name is required in .ini file!") + else: + is_valid = False + print("output0 is required in .ini file!") + return is_valid + + +def check_attr_dict(attr_dict, is_valid, attr): + """ + Function Description: + Check attr_dict + Parameter: attr_dict + Parameter: is_valid + Parameter: attr + """ + attr_type = attr_dict.get("type") + value = attr_dict.get("value") + param_type = attr_dict.get("paramType") + if attr_type is None or value is None: + is_valid = False + print("If attr.list is exist, {0}.type and {0}.value is required".format(attr)) + if param_type and param_type not in ATTR_PARAMTYPE_LIST: + is_valid = False + print("{0}.paramType only support {1}.".format(attr, ATTR_PARAMTYPE_LIST)) + if attr_type and attr_type not in ATTR_TYPE_LIST: + is_valid = False + print("{0}.type only support {1}.".format(attr, ATTR_TYPE_LIST)) + return is_valid + + +def check_attr(op_dict, is_valid): + """ + Function Description: + Check attr + Parameter: op_dict + Parameter: is_valid + """ + if "attr" in op_dict: + attr_dict = op_dict.get("attr") + attr_list_str = attr_dict.get("list", None) + if attr_list_str is None: + is_valid = False + print("attr.list is required in .ini file!") + else: + attr_list = attr_list_str.split(",") + for attr_name in attr_list: + attr = "attr_" + attr_name.strip() + attr_dict = op_dict.get(attr) + if attr_dict: + is_valid = check_attr_dict(attr_dict, is_valid, attr) + else: + is_valid = False + print( + "%s is required in .ini file, when attr.list is %s!" 
+ % (attr, attr_list_str) + ) + return is_valid + + +def check_bool_flag(op_dict, is_valid): + """ + Function Description: + check_bool_flag + Parameter: op_dict + Parameter: is_valid + """ + for key in BOOL_FLAG_KEY: + if key in op_dict: + op_bool_key = op_dict.get(key) + if op_bool_key.get("flag").strip() not in BOOL_LIST: + is_valid = False + print("{0}.flag only support {1}.".format(key, BOOL_LIST)) + return is_valid + + +def check_type_format(op_info, is_valid, op_info_key): + """ + Function Description: + Check type and format + Parameter: op_info + Parameter: is_valid + Parameter: op_info_key + """ + op_info_dtype_str = op_info.get("dtype") + op_info_dtype_num = 0 + op_info_format_num = 0 + if op_info_dtype_str: + op_info_dtype = op_info_dtype_str.split(",") + op_info_dtype_num = len(op_info_dtype) + for dtype in op_info_dtype: + if dtype.strip() not in DTYPE_LIST: + is_valid = False + print("{0}.dtype not support {1}.".format(op_info_key, dtype)) + op_info_format_str = op_info.get("format") + if op_info_format_str: + op_info_format = op_info_format_str.split(",") + op_info_format_num = len(op_info_format) + for op_format in op_info_format: + if op_format.strip() not in FORMAT_LIST: + is_valid = False + print("{0}.format not support {1}.".format(op_info_key, op_format)) + if op_info_dtype_num > 0 and op_info_format_num > 0: + if op_info_dtype_num != op_info_format_num: + is_valid = False + print( + "The number of {0}.dtype not match the number of {0}.format.".format( + op_info_key + ) + ) + return is_valid + + +def check_op_info(tbe_ops): + """ + Function Description: + Check info. + Parameter: tbe_ops + Return Value: is_valid + """ + print("\n\n==============check valid for ops info start==============") + required_op_input_info_keys = ["paramType", "name"] + required_op_output_info_keys = ["paramType", "name"] + param_type_valid_value = ["dynamic", "optional", "required"] + is_valid = True + for op_key in tbe_ops: + op_dict = tbe_ops[op_key] + for op_info_key in op_dict: + if op_info_key.startswith("input"): + op_input_info = op_dict[op_info_key] + missing_keys = [] + for required_op_input_info_key in required_op_input_info_keys: + if required_op_input_info_key not in op_input_info: + missing_keys.append(required_op_input_info_key) + if len(missing_keys) > 0: + print( + "op: " + + op_key + + " " + + op_info_key + + " missing: " + + ",".join(missing_keys) + ) + is_valid = False + else: + if not op_input_info["paramType"] in param_type_valid_value: + print( + "op: " + + op_key + + " " + + op_info_key + + " paramType not valid, valid key:[dynamic, " + "optional, required]" + ) + is_valid = False + is_valid = check_type_format(op_input_info, is_valid, op_info_key) + if op_info_key.startswith("output"): + op_input_info = op_dict[op_info_key] + missing_keys = [] + for required_op_input_info_key in required_op_output_info_keys: + if required_op_input_info_key not in op_input_info: + missing_keys.append(required_op_input_info_key) + if len(missing_keys) > 0: + print( + "op: " + + op_key + + " " + + op_info_key + + " missing: " + + ",".join(missing_keys) + ) + is_valid = False + else: + if not op_input_info["paramType"] in param_type_valid_value: + print( + "op: " + + op_key + + " " + + op_info_key + + " paramType not valid, valid key:[dynamic, " + "optional, required]" + ) + is_valid = False + is_valid = check_type_format(op_input_info, is_valid, op_info_key) + is_valid = check_attr(op_dict, is_valid) + is_valid = check_bool_flag(op_dict, is_valid) + print("==============check valid for 
ops info end================\n\n") + return is_valid + + +def write_json_file(tbe_ops_info, json_file_path): + """ + Save info to json file + Parameters: + ---------------- + tbe_ops_info: ops_info + json_file_path: json file path + ---------------- + """ + json_file_real_path = os.path.realpath(json_file_path) + wr_flag = os.O_WRONLY | os.O_CREAT + wr_mode = stat.S_IWUSR | stat.S_IRUSR + with os.fdopen(os.open(json_file_real_path, wr_flag, wr_mode), "w") as file_path: + # The owner have all rights£¬group only have read rights + os.chmod(json_file_real_path, stat.S_IWUSR + stat.S_IRGRP + stat.S_IRUSR) + json.dump( + tbe_ops_info, file_path, sort_keys=True, indent=4, separators=(",", ":") + ) + print("Compile op info cfg successfully.") + + +def parse_ini_to_json(ini_file_paths, outfile_path): + """ + parse ini files to json file + Parameters: + ---------------- + ini_file_paths: list of ini file path + outfile_path: output file path + ---------------- + """ + tbe_ops_info = parse_ini_files(ini_file_paths) + if not check_op_info(tbe_ops_info): + print("Compile op info cfg failed.") + return False + write_json_file(tbe_ops_info, outfile_path) + return True + + +if __name__ == "__main__": + args = sys.argv + + OUTPUT_FILE_PATH = "tbe_ops_info.json" + ini_file_path_list = [] + parse_ini_list = [] + + for arg in args: + if arg.endswith("ini"): + ini_file_path_list.append(arg) + OUTPUT_FILE_PATH = arg.replace(".ini", ".json") + if arg.endswith("json"): + OUTPUT_FILE_PATH = arg + + if not ini_file_path_list: + ini_file_path_list.append("tbe_ops_info.ini") + + for ini_file in ini_file_path_list: + if os.path.exists(ini_file): + parse_ini_list.append(ini_file) + + if parse_ini_list: + if not parse_ini_to_json(parse_ini_list, OUTPUT_FILE_PATH): + sys.exit(1) + sys.exit(0) diff --git a/csrc/deepep/ops2/cmake/util/preset_parse.py b/csrc/deepep/ops2/cmake/util/preset_parse.py new file mode 100755 index 00000000..983f16e7 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/preset_parse.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import json +import os +import sys + + +def read_json(file): + with open(file, "r") as fd: + config = json.load(fd) + return config + + +def get_config_opts(file): + config = read_json(file) + + src_dir = os.path.abspath(os.path.dirname(file)) + opts = "" + + for conf in config: + if conf == "configurePresets": + for node in config[conf]: + macros = node.get("cacheVariables") + if macros is not None: + for key in macros: + opts += "-D{}={} ".format(key, macros[key]["value"]) + + opts = opts.replace("${sourceDir}", src_dir) + print(opts) + + +if __name__ == "__main__": + get_config_opts(sys.argv[1]) diff --git a/csrc/deepep/ops2/cmake/util/replay_codegen.py b/csrc/deepep/ops2/cmake/util/replay_codegen.py new file mode 100755 index 00000000..6f896a09 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/replay_codegen.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- + +import collections +import os +import stat + +import code_channel_infer +import const_var +import kernel_entry as keb +from tiling_data_def_build import gen_tiling + +PYF_PATH = os.path.dirname(__file__) + +ReplayCodeGenParams = collections.namedtuple( + "ReplayCodeGenParams", + [ + "op_type", + "impl", + "tiling_file", + "kernel", + "entry", + "argn", + "op_replay_batch", + "max_block_dim", + "max_shape_size", + ], +) + + +class ReplayCodeGen: + def __init__(self, replayCodeGenParams): + self.op_type = replayCodeGenParams.op_type + self.impl = replayCodeGenParams.impl + 
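To make the data handled by parse_ini_to_json concrete: parse_ini_to_obj turns each [OpType] section into a nested dict keyed by the prefix before the dot, check_op_info then requires name/paramType on every inputN/outputN entry (and validates dtype/format against the allowed lists), and write_json_file dumps the result. A small illustrative fragment, with values chosen to mirror the DispatchLayout op registered later in this patch (the real .ini is generated by opbuild and is not shown here):

# Illustrative .ini fragment (format only):
#   [DispatchLayout]
#   input0.name=topkIdx
#   input0.paramType=required
#   input0.dtype=int64
#   input0.format=ND
#   output0.name=numTokensPerRank
#   output0.paramType=required
#   output0.dtype=int32
#   output0.format=ND
# parse_ini_to_obj would yield roughly:
#   {"DispatchLayout": {"input0":  {"name": "topkIdx", "paramType": "required",
#                                   "dtype": "int64", "format": "ND"},
#                       "output0": {"name": "numTokensPerRank", "paramType": "required",
#                                   "dtype": "int32", "format": "ND"}}}
# which passes check_op_info (paramType/name present, dtype and format in the allowed lists).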
self.tiling_file = replayCodeGenParams.tiling_file + self.tiling_data_file = "" + self.kernel = replayCodeGenParams.kernel + self.entry = replayCodeGenParams.entry + self.argn = replayCodeGenParams.argn + self.batch = False + self.outdir = "" + self.data_type = "uint8_t" + self.blknum = 32 + self.op_replay_batch = replayCodeGenParams.op_replay_batch + self.max_block_dim = replayCodeGenParams.max_block_dim + self.max_shape_size = replayCodeGenParams.max_shape_size + + def set_batch(self, is_batch): + self.batch = is_batch + + def set_outdir(self, outdir): + self.outdir = outdir + + def gen_replay(self, ops_product: str): + kerentry = os.path.join(self.outdir, self.kernel + "_entry.cce") + kerimpl = os.path.join(self.outdir, self.kernel + "_impl.cpp") + replayimpl = os.path.join(self.outdir, self.kernel + "_replay.cpp") + if self.batch: + reptmp = os.path.join(PYF_PATH, "batch_replay_impl.temp") + else: + reptmp = os.path.join(PYF_PATH, "replay_impl.temp") + kertmp = os.path.join(PYF_PATH, "kernel_impl.temp") + self._gen_kentry(kerentry) + self._gen_kimpl_code(kerimpl, kertmp) + self._gen_tiling_data_header() + self._gen_replay_code(replayimpl, reptmp, ops_product) + + def _gen_tiling_data_header(self): + self.tiling_data_file = os.path.join( + self.outdir, self.kernel + "_tiling_data.h" + ) + gen_tiling(self.tiling_file, self.tiling_data_file) + + def _gen_kimpl_code(self, src, tmpfile): + with open(tmpfile, "r") as fd: + temp = fd.read() + temp = temp.replace("__CCE_FILE__", self.impl) + with os.fdopen(os.open(src, const_var.WFLAGS, const_var.WMODES), "w") as ofd: + ofd.write(temp) + + def _gen_replay_code(self, src, tmpfile, ops_product: str): + with open(tmpfile, "r") as fd: + temp = fd.read() + temp = temp.replace("__ARG_NUM__", str(self.argn)) + argdef = [] + kargs = [] + for i in range(0, self.argn): + argdef.append("{} *".format(self.data_type)) + kargs.append("({} *)GetArg({})".format(self.data_type, i)) + temp = temp.replace("__ARGS_DEF__", ", ".join(argdef)) + temp = temp.replace("__KERNEL_ARGS__", ", ".join(kargs)) + temp = temp.replace("__KERNEL_FUN__", self.entry) + core_type_infer = "core_type" + code_channel = code_channel_infer.infer_code_channel( + code_channel_infer.InfoCodeChanelParams( + self.impl, + self.tiling_data_file, + self.kernel, + self.outdir, + ops_product, + None, + ) + ) + if code_channel == code_channel_infer.CODE_VEC: + core_type_infer = "0" + elif code_channel == code_channel_infer.CODE_CUBE: + core_type_infer = "1" + temp = temp.replace("__CORE_TYPE__", core_type_infer) + # register function + temp = temp.replace("__OPS_PRODUCT__", ops_product) + temp = temp.replace("__OPTYPE__", self.op_type) + with os.fdopen(os.open(src, const_var.WFLAGS, const_var.WMODES), "w") as ofd: + ofd.write(temp) + + def _gen_kentry(self, src): + kf = "" + pre_alloc_str = "A" * 256 + if self.batch: + kf += keb.batch_code_gen( + "K{:02d}_{}{}".format(0, self.entry, pre_alloc_str), + self.argn, + self.data_type, + ) + else: + kf += keb.mc_code_gen( + "K{:02d}_{}{}".format(0, self.entry, pre_alloc_str), + self.argn, + self.data_type, + self.blknum, + ) + with os.fdopen(os.open(src, const_var.WFLAGS, const_var.WMODES), "w") as ofd: + ofd.write(kf) diff --git a/csrc/deepep/ops2/cmake/util/replay_impl.temp b/csrc/deepep/ops2/cmake/util/replay_impl.temp new file mode 100755 index 00000000..1d30dd86 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/replay_impl.temp @@ -0,0 +1,120 @@ +#include +#include +#include +#include +#include +#include +#include "replay_def.h" +#include "code_gen.h" 
+#include "replay_fun.h" +#include "register/op_check.h" +#define __ASCENDC_REPLAY_CODE__ +using namespace std; +using namespace optiling; +using namespace AscendCReplay; + +extern "C" void __KERNEL_FUN__ (__ARGS_DEF__, const char *); +extern "C" int elf_append(char *elf, uint32_t elfSize, char *jit, int kernum, int blknum[], char *atext[], + int alen[], int atlen, const char* kernelname[]); + +#define KERNEL_N 1 +#define ARG_N (__ARG_NUM__) +#define MAX_L (1024 * 1024 * 100) +#define MAX_E (1024 * 1024) + +int __KERNEL_FUN___replay___OPS_PRODUCT__(ReplayFuncParam& param, const int core_type) +{ + // gen type 1 : direct call codes 0: load .o file + if (param.gentype < 0 || param.gentype > 1) { + printf("Error: call replay gen type is %d, should only be 1 or 0\n", param.gentype); + return 0; + } else if (param.gentype == 1 && param.objptr == nullptr) { + printf("Error: call replay with direct call mode, but code obj addr is null\n"); + return 0; + } else if (param.gentype == 0 && param.output_kernel_file == nullptr) { + printf("Error: call replay with object file mode, but object file path is null\n"); + return 0; + } + // core_type 0:MIX 1:CUBE 2:VEC + if (core_type < 0 || core_type > 2) { + printf("Error: call replay core type is %d !\n", core_type); + return 0; + } + g_coreType = __CORE_TYPE__; + g_taskRation = param.task_ration; + g_tilingKey = param.tiling_key; + + unsigned char *buf, *jit; + char *kernel[KERNEL_N * 32]; + int len[KERNEL_N * 32]; + int blknum[KERNEL_N]; + int max; + block_num = param.block_dim; + g_ubBase = block_num; + uint8_t *code = (uint8_t *)malloc(MAX_L); + uint8_t *pos = code; + struct timespec tp1, tp2; + + clock_gettime(CLOCK_MONOTONIC, &tp1); + if (block_num > 32) { + printf("Error: block_num > 32\n"); + return 0; + } + //__OP_FOPEN__ + for (int i = 0; i < KERNEL_N; i++) { + for (int j = 0; j < ARG_N; j++) + AddArg(j, ARG_STEP * (j + 1)); + for (block_idx = 0; block_idx < block_num; block_idx++) { + //__OP_SET_KERNEL__ + int code_idx = i * block_num + block_idx; +#ifdef FP_CEILING + SetCtrlFloatEnable(); +#else + SetCtrlFloatDisable(); +#endif + CodeInit(pos, false); + __KERNEL_FUN__(__KERNEL_ARGS__, param.tiling_data); + CodeEnd(); + kernel[code_idx] = (char *)pos; + len[code_idx] = CodeLen(); + pos += len[code_idx]; + printf("kernel %d core %ld code generated len %d\n", i, block_idx, len[code_idx]); + } + blknum[i] = block_num; + } + //__OP_FCLOSE__ + clock_gettime(CLOCK_MONOTONIC, &tp2); + buf = (unsigned char *)malloc(MAX_E); + int fd = open(param.entry_file, O_RDONLY); + if (fd < 0) { + printf("[error]: cannot find entry.o : %s\n", param.entry_file); + return 0; + } + uint32_t bufSize = read(fd, buf, MAX_E); + if (bufSize <= 0) { + printf("[error]: entry.o : %s is too small ! 
\n", param.entry_file); + } + close(fd); + jit = (unsigned char *)malloc(MAX_L); + printf("total code generated %ld\n", pos - code); + int sz = elf_append((char *)buf, bufSize, (char *)jit, KERNEL_N, blknum, kernel, len, pos - code, ¶m.kernel_name); + if (tp1.tv_sec != tp2.tv_sec) { + printf("%ld NS\n", tp2.tv_nsec + 1000000000 - tp1.tv_nsec); + } else { + printf("%ld NS\n", tp2.tv_nsec - tp1.tv_nsec); + } + printf("new elf size %d\n", sz); + if (param.gentype == 0) { + fd = open(param.output_kernel_file, O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR); + (void)write(fd, jit, sz); + close(fd); + free(jit); + } else if (param.gentype == 1) { + *param.objptr = (char*)jit; + } + free(buf); + free(code); + return sz; +} + +REG_REPLAY_FUNC(__OPTYPE__, __OPS_PRODUCT__, __KERNEL_FUN___replay___OPS_PRODUCT__); diff --git a/csrc/deepep/ops2/cmake/util/tiling_data_def_build.py b/csrc/deepep/ops2/cmake/util/tiling_data_def_build.py new file mode 100755 index 00000000..0576b202 --- /dev/null +++ b/csrc/deepep/ops2/cmake/util/tiling_data_def_build.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +Function: +The replay function entry +""" + +import os +import re +import stat +import sys + +import const_var + + +def gen_tiling(tiling_header_file: str, tiling_file_out: str): + if not os.path.exists(tiling_header_file): + print("warning: no userdef tiling header file: ", tiling_header_file) + return + print("generate tiling def header file: ", tiling_file_out) + tmp_name = os.path.splitext(os.path.basename(tiling_header_file))[0].upper() + tiling_source = "#ifndef __{}_H__\n".format(tmp_name) + tiling_source += "#define __{}_H__\n\n".format(tmp_name) + tiling_source += "#include \n" + tiling_source += "#include \n\n" + tiling_source += '#include "kernel_tiling/kernel_tiling.h"\n\n' + end_source = "" + pattern = re.compile(r"[(](.*)[)]", re.S) + with open(tiling_header_file, "r") as fd: + lines = fd.readlines() + for line in lines: + line = line.strip() + if line.startswith("BEGIN_TILING_DATA_DEF"): + tiling_source += "#pragma pack(1)\n" + tiling_source += "struct " + struct_def = re.findall(pattern, line)[0] + tiling_source += struct_def + " {\n" + elif line.startswith("TILING_DATA_FIELD_DEF_ARR"): + field_params = re.findall(pattern, line)[0] + fds = field_params.split(",") + tiling_source += " {} {}[{}] = {{}};\n".format( + fds[0].strip(), fds[2].strip(), fds[1].strip() + ) + elif line.startswith("TILING_DATA_FIELD_DEF_STRUCT"): + field_params = re.findall(pattern, line)[0] + fds = field_params.split(",") + tiling_source += " {} {};\n".format(fds[0].strip(), fds[1].strip()) + elif line.startswith("TILING_DATA_FIELD_DEF"): + field_params = re.findall(pattern, line)[0] + fds = field_params.split(",") + tiling_source += " {} {} = 0;\n".format( + fds[0].strip(), fds[1].strip() + ) + elif line.startswith("END_TILING_DATA_DEF"): + tiling_source += "};\n" + tiling_source += "#pragma pack()\n\n" + tiling_source += "#ifdef __NPU_TILING__\n" + tiling_source += "inline [aicore] void Init{stru}(const __gm__ uint8_t* tiling, {stru}* const_data)\n".format( + stru=struct_def + ) + tiling_source += "{\n" + tiling_source += " const __gm__ uint32_t *src = (const __gm__ uint32_t *)tiling;\n" + tiling_source += " uint32_t *dst = (uint32_t *)const_data;\n" + tiling_source += " for (auto i = 0; i < sizeof({}) / 4; i++) *(dst + i) = *(src + i);\n".format( + struct_def + ) + tiling_source += "}\n" + tiling_source += "#else\n" + tiling_source += "inline void Init{stru}(uint8_t* tiling, {stru}* 
const_data)\n".format( + stru=struct_def + ) + tiling_source += "{\n" + tiling_source += " uint64_t *src = (uint64_t *)tiling;\n" + tiling_source += " uint64_t *dst = (uint64_t *)const_data;\n" + tiling_source += " for (auto i = 0; i < sizeof({}) / 8; i++) *(dst + i) = *(src + i);\n".format( + struct_def + ) + tiling_source += "}\n" + tiling_source += "#endif\n\n" + end_source = """ +#undef GET_TILING_DATA +#define GET_TILING_DATA(tiling_data, tiling_arg) \\ +{stru} tiling_data; \\ +Init{stru}(tiling_arg, &tiling_data)\n +""".format( + stru=struct_def + ) + tiling_source += end_source + tiling_source += "#endif" + with os.fdopen( + os.open(tiling_file_out, const_var.WFLAGS, const_var.WMODES), "w" + ) as ofd: + ofd.write(tiling_source) + + +if __name__ == "__main__": + if len(sys.argv) <= 2: + raise RuntimeError("arguments must greater than 2") + gen_tiling(sys.argv[1], sys.argv[2]) diff --git a/csrc/deepep/ops2/op_host/CMakeLists.txt b/csrc/deepep/ops2/op_host/CMakeLists.txt new file mode 100644 index 00000000..f40147f5 --- /dev/null +++ b/csrc/deepep/ops2/op_host/CMakeLists.txt @@ -0,0 +1,174 @@ + +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} ops_srcs) + +opbuild(OPS_SRC ${ops_srcs} + OUT_DIR ${ASCEND_AUTOGEN_PATH} +) + +file(GLOB group_proto_src ${ASCEND_AUTOGEN_PATH}/group_proto/*.cc) + +add_library(cust_op_proto SHARED + $<$:${group_proto_src}> + ${ops_srcs} + ${ASCEND_AUTOGEN_PATH}/op_proto.cc +) +target_compile_definitions(cust_op_proto PRIVATE OP_PROTO_LIB) +target_compile_options(cust_op_proto PRIVATE + -fvisibility=hidden +) +if(ENABLE_CROSS_COMPILE) + target_link_directories(cust_op_proto PRIVATE + ${CMAKE_COMPILE_COMPILER_LIBRARY} + ${CMAKE_COMPILE_RUNTIME_LIBRARY} + ) +endif() +target_link_libraries(cust_op_proto PRIVATE + intf_pub + exe_graph + register + tiling_api + -Wl,--whole-archive + rt2_registry + -Wl,--no-whole-archive +) +set_target_properties(cust_op_proto PROPERTIES OUTPUT_NAME + cust_opsproto_rt2.0 +) +file(GLOB fallback_src ${ASCEND_AUTOGEN_PATH}/fallback_*.cpp) +add_library(cust_optiling SHARED ${ops_srcs}) +if (${fallback_src}) + target_sources(cust_optiling PRIVATE ${fallback_src}) +endif() +target_compile_definitions(cust_optiling PRIVATE OP_TILING_LIB) +target_compile_options(cust_optiling PRIVATE + -fvisibility=hidden +) +if(ENABLE_CROSS_COMPILE) + target_link_directories(cust_optiling PRIVATE + ${CMAKE_COMPILE_COMPILER_LIBRARY} + ${CMAKE_COMPILE_RUNTIME_LIBRARY} + ) +endif() +target_link_libraries(cust_optiling PRIVATE + nnopbase + intf_pub + exe_graph + register + tiling_api + -Wl,--whole-archive + rt2_registry + -Wl,--no-whole-archive +) +set_target_properties(cust_optiling PROPERTIES OUTPUT_NAME + cust_opmaster_rt2.0 +) + +file(GLOB_RECURSE pregen_file + "${CMAKE_CURRENT_SOURCE_DIR}/op_api/*" +) + +file(COPY ${pregen_file} DESTINATION ${ASCEND_AUTOGEN_PATH}) +file(GLOB aclnn_src ${ASCEND_AUTOGEN_PATH}/aclnn*.cpp) +file(GLOB aclnn_inc ${ASCEND_AUTOGEN_PATH}/aclnn_*.h) +if(NOT ASCEND_PACK_SHARED_LIBRARY) + add_library(cust_opapi SHARED ${aclnn_src}) +else() + file(GLOB op_registry ${ASCEND_AUTOGEN_PATH}/custom_op_registry.cpp) + add_library(cust_opapi SHARED ${aclnn_src} ${op_registry}) + target_compile_definitions(cust_opapi PRIVATE ACLNN_WITH_BINARY) +endif() +if(ENABLE_CROSS_COMPILE) + target_link_directories(cust_opapi PRIVATE + ${CMAKE_COMPILE_COMPILER_LIBRARY} + ${CMAKE_COMPILE_RUNTIME_LIBRARY} + ) +endif() +if(NOT ASCEND_PACK_SHARED_LIBRARY) + target_link_libraries(cust_opapi PRIVATE intf_pub ascendcl nnopbase) +else() + 
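As a reference for the header produced by gen_tiling in tiling_data_def_build.py above, the following sketch shows the translation it performs; DemoTilingData and its fields are placeholders rather than names taken from this patch:

# Input TILING_DATA_DEF macros (placeholder struct and fields):
#   BEGIN_TILING_DATA_DEF(DemoTilingData)
#   TILING_DATA_FIELD_DEF(uint32_t, numTokens)
#   TILING_DATA_FIELD_DEF_ARR(uint32_t, 8, perRankCnt)
#   END_TILING_DATA_DEF
# Generated header (roughly):
#   #pragma pack(1)
#   struct DemoTilingData {
#       uint32_t numTokens = 0;
#       uint32_t perRankCnt[8] = {};
#   };
#   #pragma pack()
# plus InitDemoTilingData() copy helpers and a GET_TILING_DATA redefinition that
# instantiates the struct from the raw tiling pointer.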
add_library(cust_op_proto_obj OBJECT + $<$:${group_proto_src}> + ${ops_srcs} + ${ASCEND_AUTOGEN_PATH}/op_proto.cc + ) + target_compile_definitions(cust_op_proto_obj PRIVATE OP_PROTO_LIB) + target_compile_options(cust_op_proto_obj PRIVATE + -fvisibility=hidden + ) + if(ENABLE_CROSS_COMPILE) + target_link_directories(cust_op_proto_obj PRIVATE + ${CMAKE_COMPILE_COMPILER_LIBRARY} + ${CMAKE_COMPILE_RUNTIME_LIBRARY} + ) + endif() + target_link_libraries(cust_op_proto_obj PRIVATE + intf_pub + exe_graph + register + tiling_api + -Wl,--whole-archive + rt2_registry + -Wl,--no-whole-archive + ) + add_library(cust_optiling_obj OBJECT ${ops_srcs}) + target_compile_definitions(cust_optiling_obj PRIVATE OP_TILING_LIB) + target_compile_options(cust_optiling_obj PRIVATE + -fvisibility=hidden + ) + if(ENABLE_CROSS_COMPILE) + target_link_directories(cust_optiling_obj PRIVATE + ${CMAKE_COMPILE_COMPILER_LIBRARY} + ${CMAKE_COMPILE_RUNTIME_LIBRARY} + ) + endif() + target_link_libraries(cust_optiling_obj PRIVATE + intf_pub + exe_graph + register + tiling_api + -Wl,--whole-archive + rt2_registry + -Wl,--no-whole-archive + ) + target_compile_options(cust_opapi PRIVATE -DLOG_CPP) + target_include_directories(cust_opapi INTERFACE ${CMAKE_SOURCE_DIR}/build_out/library/) + target_link_libraries(cust_opapi PRIVATE intf_pub ascendcl nnopbase cust_optiling_obj cust_op_proto_obj ascend_opregistry ascend_kernels) + add_dependencies(cust_opapi ascend_opregistry) +endif() + +target_include_directories(cust_opapi PRIVATE $ENV{ASCEND_HOME_PATH}/${CANN_HOST_ARCH}-linux/include/experiment/platform/) +include_directories($ENV{ASCEND_HOME_PATH}/../opp/vendors/CAM/op_impl/ai_core/tbe/CAM_impl/dynamic/) + +add_custom_target(optiling_compat ALL + COMMAND ln -sf lib/linux/${CMAKE_SYSTEM_PROCESSOR}/$ + ${CMAKE_CURRENT_BINARY_DIR}/liboptiling.so +) +if(NOT ASCEND_PACK_SHARED_LIBRARY) + install(TARGETS cust_op_proto + LIBRARY DESTINATION packages/vendors/${vendor_name}/op_proto/lib/linux/${CMAKE_SYSTEM_PROCESSOR}) + install(FILES ${ASCEND_AUTOGEN_PATH}/op_proto.h + DESTINATION packages/vendors/${vendor_name}/op_proto/inc) + file(GLOB GROUP_PROTO_HEADERS ${ASCEND_AUTOGEN_PATH}/group_proto/*.h) + if (GROUP_PROTO_HEADERS) + install(FILES ${GROUP_PROTO_HEADERS} + DESTINATION packages/vendors/${vendor_name}/op_proto/inc) + endif() + install(TARGETS cust_optiling + LIBRARY DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_tiling/lib/linux/${CMAKE_SYSTEM_PROCESSOR}) + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/liboptiling.so + DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_tiling) + install(TARGETS cust_opapi + LIBRARY DESTINATION packages/vendors/${vendor_name}/op_api/lib) + install(FILES ${aclnn_inc} + DESTINATION packages/vendors/${vendor_name}/op_api/include) +else() + file(GLOB group_inc ${ASCEND_AUTOGEN_PATH}/group_proto/*.h) + install(TARGETS cust_opapi + LIBRARY DESTINATION op_api/lib) + install(FILES ${ASCEND_AUTOGEN_PATH}/op_proto.h + DESTINATION op_api/include) + install(FILES ${group_inc} + DESTINATION op_api/include) + install(FILES ${aclnn_inc} + DESTINATION op_api/include) +endif() diff --git a/csrc/deepep/ops2/op_host/dispatch_layout.cpp b/csrc/deepep/ops2/op_host/dispatch_layout.cpp new file mode 100644 index 00000000..48de2e77 --- /dev/null +++ b/csrc/deepep/ops2/op_host/dispatch_layout.cpp @@ -0,0 +1,70 @@ +#include "register/op_def_registry.h" + +namespace ops { +class DispatchLayout : public OpDef +{ +public: + explicit DispatchLayout(const char *name) : OpDef(name) + { + 
this->Input("topkIdx") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Attr("num_tokens").Int(); + this->Attr("num_ranks").Int(); + this->Attr("num_experts").Int(); + this->Attr("num_topk").Int(); + this->Attr("local_ranksize").Int(); + + this->Output("numTokensPerRank") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("numTokensPerExpert") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("isTokenInRank") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("totalData") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + OpAICoreConfig a3_config; + a3_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") + .ExtendCfgInfo("jitCompile.flag", "static_true") + .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); + + OpAICoreConfig a2_config; + a2_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") + .ExtendCfgInfo("jitCompile.flag", "static_false") + .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); + + this->AICore().AddConfig("ascend910_93", a3_config); + this->AICore().AddConfig("ascend910b", a2_config); + } +}; + +OP_ADD(DispatchLayout); +} // namespace ops diff --git a/csrc/deepep/ops2/op_host/dispatch_layout_tiling.cc b/csrc/deepep/ops2/op_host/dispatch_layout_tiling.cc new file mode 100644 index 00000000..f566fa22 --- /dev/null +++ b/csrc/deepep/ops2/op_host/dispatch_layout_tiling.cc @@ -0,0 +1,252 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "error_log.h" +#include "graph/utils/type_utils.h" +#include "register/op_def_registry.h" +#include "../op_kernel/dispatch_layout_tiling.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/hccl/hccl_tiling.h" +#include "experiment/platform/platform/platform_infos_def.h" + +using namespace ge; +namespace { +constexpr uint32_t INPUT_TOPK_IDX_INDEX = 0; + +constexpr uint32_t OUTPUT_NUM_TOKEN_PER_RANK_INDEX = 0; +constexpr uint32_t OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX = 1; +constexpr uint32_t OUTPUT_IS_TOKEN_IN_RANK_INDEX = 2; +constexpr uint32_t OUTPUT_TOTAL_DATA_INDEX = 3; + +constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 0; +constexpr uint32_t ATTR_NUM_RANKS_INDEX = 1; +constexpr uint32_t ATTR_NUM_EXPERTS_INDEX = 2; +constexpr uint32_t ATTR_NUM_TOPK_INDEX = 3; +constexpr uint32_t ATTR_LOCAL_RANKSIZE_INDEX = 4; +const int64_t MAX_COMM_WORLD_SIZE = 384; +const int64_t MAX_MOE_EXPERTS_NUM = 512; +const int64_t MAX_A2_LOCAL_RANKSIZE = 8; +constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; +constexpr uint32_t KERNEL_USE_WORKSPACE = 1 * 1024 * 1024; +constexpr uint32_t KERNEL_A2_ARG_SIZE = 1 * 1024 * 1024; + +constexpr static int TILING_KEY_INT = 23; +constexpr static int TILING_KEY_A2_TYPE = 100; + +constexpr uint32_t TWO_DIMS = 2; 
+constexpr uint32_t K_MAX = 16; +} // namespace + +namespace optiling { +static void PrintTilingDataInfo(const char *nodeName, DispatchLayoutTilingData &tilingData) +{ + OP_LOGD(nodeName, "numToken is %u.", tilingData.dispatchLayoutInfo.numTokens); + OP_LOGD(nodeName, "numRanks is %u.", tilingData.dispatchLayoutInfo.numRanks); + OP_LOGD(nodeName, "numExperts is %u.", tilingData.dispatchLayoutInfo.numExperts); + OP_LOGD(nodeName, "numTopk is %u.", tilingData.dispatchLayoutInfo.numTopk); + OP_LOGD(nodeName, "localRankSize is %u.", tilingData.dispatchLayoutInfo.localRankSize); + OP_LOGD(nodeName, "totalUbSize is %lu.", tilingData.dispatchLayoutInfo.totalUbSize); +} + +static bool CheckIfA2Machine(gert::TilingContext *context) +{ + fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo(); + fe::PlatFormInfos &platformInfo = *platformInfoPtr; + + std::string socVersion; + (void)platformInfo.GetPlatformResWithLock("version", "Short_SoC_version", socVersion); + + if (socVersion == "Ascend910B") { + return true; + } + return false; +} + +static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName, + DispatchLayoutTilingData &tilingData) +{ + auto attrs = context->GetAttrs(); + OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); + + auto numTokensPtr = attrs->GetAttrPointer(static_cast(ATTR_NUM_TOKENS_INDEX)); + auto numRanksPtr = attrs->GetAttrPointer(static_cast(ATTR_NUM_RANKS_INDEX)); + auto numExpertsPtr = attrs->GetAttrPointer(ATTR_NUM_EXPERTS_INDEX); + auto numTopkPtr = attrs->GetAttrPointer(static_cast(ATTR_NUM_TOPK_INDEX)); + auto localRankSizePtr = attrs->GetAttrPointer(static_cast(ATTR_LOCAL_RANKSIZE_INDEX)); + + OP_TILING_CHECK(numTokensPtr == nullptr, OP_LOGE(nodeName, "numTokensPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(numRanksPtr == nullptr, OP_LOGE(nodeName, "numRanksPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(numExpertsPtr == nullptr, OP_LOGE(nodeName, "numExpertsPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(numTopkPtr == nullptr, OP_LOGE(nodeName, "numTopkPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(localRankSizePtr == nullptr, OP_LOGE(nodeName, "localRankSizePtr is null."), + return ge::GRAPH_FAILED); + + OP_TILING_CHECK((*numRanksPtr <= 0) || (*numRanksPtr > MAX_COMM_WORLD_SIZE), + OP_LOGE(nodeName, "rankSize is invalid, only support (0, %ld], but got rankSize=%ld.", + MAX_COMM_WORLD_SIZE, *numRanksPtr), + return ge::GRAPH_FAILED); + OP_TILING_CHECK((*numExpertsPtr <= 0) || (*numExpertsPtr > MAX_MOE_EXPERTS_NUM), + OP_LOGE(nodeName, "numExperts is invalid, only support (0, %ld], but got numExperts=%ld.", + MAX_MOE_EXPERTS_NUM, *numExpertsPtr), + return ge::GRAPH_FAILED); + OP_TILING_CHECK( + (*numTopkPtr <= 0) || (*numTopkPtr > K_MAX), + OP_LOGE(nodeName, "numTopkPtr is invalid, only support (0, %u], but got numTopk=%ld.", K_MAX, *numTopkPtr), + return ge::GRAPH_FAILED); + + if (CheckIfA2Machine(context)) { + OP_TILING_CHECK( + (*localRankSizePtr <= 0) || (*localRankSizePtr > MAX_A2_LOCAL_RANKSIZE), + OP_LOGE(nodeName, "localRankSizePtr is invalid, only support (0, %ld], but got localRankSize=%ld.", + MAX_A2_LOCAL_RANKSIZE, *localRankSizePtr), + return ge::GRAPH_FAILED); + } + + tilingData.dispatchLayoutInfo.numTokens = static_cast(*numTokensPtr); + tilingData.dispatchLayoutInfo.numRanks = static_cast(*numRanksPtr); + tilingData.dispatchLayoutInfo.numExperts = static_cast(*numExpertsPtr); + 
tilingData.dispatchLayoutInfo.numTopk = static_cast(*numTopkPtr); + tilingData.dispatchLayoutInfo.localRankSize = static_cast(*localRankSizePtr); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName) +{ + size_t *workSpaces = context->GetWorkspaceSizes(1); + OP_TILING_CHECK(workSpaces == nullptr, OP_LOGE(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED); + workSpaces[0] = SYSTEM_NEED_WORKSPACE + KERNEL_USE_WORKSPACE + KERNEL_A2_ARG_SIZE; + return ge::GRAPH_SUCCESS; +} + +static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName) +{ + auto topkIdx = context->GetInputDesc(INPUT_TOPK_IDX_INDEX); + auto numTokensPerRank = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_RANK_INDEX); + auto numTokensPerExpert = context->GetOutputDesc(OUTPUT_NUM_TOKEN_PER_EXPERT_INDEX); + auto isTokenInRank = context->GetOutputDesc(OUTPUT_IS_TOKEN_IN_RANK_INDEX); + auto totalData = context->GetOutputDesc(OUTPUT_TOTAL_DATA_INDEX); + + OP_TILING_CHECK(topkIdx == nullptr, OP_LOGE(nodeName, "topkIdx is null."), return false); + OP_TILING_CHECK(numTokensPerRank == nullptr, OP_LOGE(nodeName, "numTokensPerRank is null."), return false); + OP_TILING_CHECK(numTokensPerExpert == nullptr, OP_LOGE(nodeName, "numTokensPerExpert is null."), return false); + OP_TILING_CHECK(isTokenInRank == nullptr, OP_LOGE(nodeName, "isTokenInRank is null."), return false); + OP_TILING_CHECK(totalData == nullptr, OP_LOGE(nodeName, "totalData is null."), return false); + + OP_TILING_CHECK((topkIdx->GetDataType() != ge::DT_INT64), + OP_LOGE(nodeName, "topkIdx datatype is invalid, datatype should be int, but is %d.", + static_cast(topkIdx->GetDataType())), + return false); + OP_TILING_CHECK((numTokensPerRank->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, "numTokensPerRank datatype is invalid, datatype should be int, but is %d.", + static_cast(numTokensPerRank->GetDataType())), + return false); + OP_TILING_CHECK((numTokensPerExpert->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, "numTokensPerExpert datatype is invalid, datatype should be int, but is %d.", + static_cast(numTokensPerExpert->GetDataType())), + return false); + OP_TILING_CHECK((isTokenInRank->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, "isTokenInRank datatype is invalid, datatype should be int, but is %d.", + static_cast(isTokenInRank->GetDataType())), + return false); + OP_TILING_CHECK((totalData->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, "totalData datatype is invalid, datatype should be int, but is %d.", + static_cast(totalData->GetDataType())), + return false); + + return true; +} + +static bool CheckTensorShape(gert::TilingContext *context, const char *nodeName) +{ + const gert::StorageShape *topkIdxStorageShape = context->GetInputShape(INPUT_TOPK_IDX_INDEX); + int64_t topkIdxDim0 = topkIdxStorageShape->GetStorageShape().GetDim(0); + int64_t topkIdxDim1 = topkIdxStorageShape->GetStorageShape().GetDim(1); + + OP_TILING_CHECK((topkIdxStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS), + OP_LOGE(nodeName, "topkIdx must be 2-dimension, but get %lu dim.", + topkIdxStorageShape->GetStorageShape().GetDimNum()), + return false); + + return true; +} + +static ge::graphStatus TilingCheckTensor(gert::TilingContext *context, const char *nodeName) +{ + OP_TILING_CHECK(!CheckTensorDataType(context, nodeName), OP_LOGE(nodeName, "params dataType is invalid."), + return ge::GRAPH_FAILED); + + OP_TILING_CHECK(!CheckTensorShape(context, nodeName), 
+                    OP_LOGE(nodeName, "params shape is invalid."),
+                    return ge::GRAPH_FAILED);
+
+    return ge::GRAPH_SUCCESS;
+}
+
+static ge::graphStatus DispatchLayoutTilingFuncImpl(gert::TilingContext *context)
+{
+    const char *nodeName = context->GetNodeName();
+    DispatchLayoutTilingData *tilingData = context->GetTilingData();
+    OP_TILING_CHECK(tilingData == nullptr, OP_LOGE(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED);
+    OP_LOGI(nodeName, "Enter DispatchLayout tiling check func.");
+
+    OP_TILING_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS,
+                    OP_LOGE(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED);
+
+    OP_TILING_CHECK(TilingCheckTensor(context, nodeName) != ge::GRAPH_SUCCESS,
+                    OP_LOGE(nodeName, "Tiling check param failed."), return ge::GRAPH_FAILED);
+
+    OP_TILING_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS,
+                    OP_LOGE(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED);
+
+    int tilingKey = TILING_KEY_INT;
+    if (CheckIfA2Machine(context)) {
+        tilingKey = tilingKey + TILING_KEY_A2_TYPE;
+    }
+    context->SetTilingKey(tilingKey);
+
+    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
+    uint32_t aivNum = ascendcPlatform.GetCoreNumAiv();
+    uint64_t ubSize = 0UL;
+    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);
+
+    uint32_t blockDim = aivNum;
+    context->SetBlockDim(blockDim);
+    tilingData->dispatchLayoutInfo.totalUbSize = ubSize;
+    OP_LOGD(nodeName, "blockDim=%u, aivNum=%u, ubSize=%lu", blockDim, aivNum, ubSize);
+    PrintTilingDataInfo(nodeName, *tilingData);
+    return ge::GRAPH_SUCCESS;
+}
+
+static ge::graphStatus DispatchLayoutTilingFunc(gert::TilingContext *context)
+{
+    return DispatchLayoutTilingFuncImpl(context);
+}
+
+struct DispatchLayoutCompileInfo {};
+ge::graphStatus TilingParseForDispatchLayout(gert::TilingParseContext *context)
+{
+    (void)context;
+    return ge::GRAPH_SUCCESS;
+}
+
+IMPL_OP_OPTILING(DispatchLayout)
+    .Tiling(DispatchLayoutTilingFunc)
+    .TilingParse(TilingParseForDispatchLayout);
+} // namespace optiling
diff --git a/csrc/deepep/ops2/op_host/dispatch_normal_a2.cpp b/csrc/deepep/ops2/op_host/dispatch_normal_a2.cpp
new file mode 100644
index 00000000..9eacea5b
--- /dev/null
+++ b/csrc/deepep/ops2/op_host/dispatch_normal_a2.cpp
@@ -0,0 +1,147 @@
+#include "register/op_def_registry.h"
+
+namespace ops {
+class DispatchNormalA2 : public OpDef
+{
+public:
+    explicit DispatchNormalA2(const char *name) : OpDef(name)
+    {
+        this->Input("x")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .AutoContiguous();
+        this->Input("expert_ids")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .AutoContiguous();
+        this->Input("scales")
+            .ParamType(OPTIONAL)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+            .AutoContiguous();
+        this->Input("x_active_mask")
+            .ParamType(OPTIONAL)
+
.DataType({ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("expert_scales") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("tokenServerIdx") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("tokenServerCnt") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("epRankTokenCnt") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("srcOffsetRankTokenIdx") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("dstOffsetRankTokenIdx") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + + this->Output("recv_x") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_INT8, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("dynamic_scales") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("expand_idx") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + + this->Output("expert_token_nums") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + + this->Output("ep_recv_count") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + + this->Output("expand_scales") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + + this->Output("dispatch_wait_recv_cost_stats") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + + this->Attr("group_ep").AttrType(REQUIRED).String(); + 
this->Attr("ep_world_size").AttrType(REQUIRED).Int(); + this->Attr("ep_rank_id").AttrType(REQUIRED).Int(); + this->Attr("moe_expert_num").AttrType(REQUIRED).Int(); + this->Attr("group_tp").AttrType(OPTIONAL).String(""); + this->Attr("tp_world_size").AttrType(OPTIONAL).Int(0); + this->Attr("tp_rank_id").AttrType(OPTIONAL).Int(0); + this->Attr("expert_shard_type").AttrType(OPTIONAL).Int(0); + this->Attr("shared_expert_num").AttrType(OPTIONAL).Int(1); + this->Attr("shared_expert_rank_num").AttrType(OPTIONAL).Int(0); + this->Attr("quant_mode").AttrType(OPTIONAL).Int(0); + this->Attr("global_bs").AttrType(OPTIONAL).Int(0); + this->Attr("expert_token_nums_type").AttrType(OPTIONAL).Int(1); + + OpAICoreConfig aicore_config_base; + aicore_config_base.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") + .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); + + OpAICoreConfig aicore_config_A2 = aicore_config_base; + aicore_config_A2.ExtendCfgInfo("jitCompile.flag", "static_false"); + + OpAICoreConfig aicore_config = aicore_config_base; + aicore_config.ExtendCfgInfo("jitCompile.flag", "static_true"); + + this->AICore().AddConfig("ascend910_93", aicore_config); + this->AICore().AddConfig("ascend910b", aicore_config_A2); + this->MC2().HcclGroup("group_ep"); + } +}; + +OP_ADD(DispatchNormalA2); +} // namespace ops diff --git a/csrc/deepep/ops2/op_host/dispatch_normal_a2_tiling.cpp b/csrc/deepep/ops2/op_host/dispatch_normal_a2_tiling.cpp new file mode 100644 index 00000000..128415b0 --- /dev/null +++ b/csrc/deepep/ops2/op_host/dispatch_normal_a2_tiling.cpp @@ -0,0 +1,1140 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "graph/utils/type_utils.h" +#include "register/op_def_registry.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/hccl/hccl_tiling.h" +#include "experiment/platform/platform/platform_infos_def.h" +#include "error_log.h" +#include "../op_kernel/cam_moe_distribute_dispatch_tiling.h" +#include "tiling_args.h" + +using namespace AscendC; +using namespace ge; +using namespace Cam; + +namespace { +class Mc2TilingUtils +{ +public: +#define HCCL_BUFFSIZE "HCCL_BUFFSIZE" + static uint64_t GetMaxWindowSize() + { + uint16_t defaultWindowSize = 200; + if (getenv(HCCL_BUFFSIZE) == nullptr) { + OP_LOGD("", "Env HCCL_BUFFSIZE don't set"); + } else { + try { + std::string envStr(getenv(HCCL_BUFFSIZE)); + defaultWindowSize = std::stoi(envStr); + } catch (const std::invalid_argument &ia) { + OP_LOGE("", "Invalid argument when parsing HCCL_BUFFSIZE: %s", ia.what()); + } catch (const std::out_of_range &oor) { + OP_LOGE("", "Out of range when parsing HCCL_BUFFSIZE: %s", oor.what()); + } + } + const uint64_t maxWindowSize = static_cast(defaultWindowSize) * 1024UL * 1024UL; + OP_LOGI("", "Get maxWindowSize is %lu", maxWindowSize); + return maxWindowSize; + } +}; +constexpr uint32_t X_INDEX = 0; +constexpr uint32_t EXPERT_IDS_INDEX = 1; +constexpr uint32_t SCALES_INDEX = 2; + +constexpr uint32_t TOKEN_SERVER_IDX_INDEX = 5; +constexpr uint32_t TOKEN_SERVER_CNT_INDEX = 6; +constexpr uint32_t EP_RANK_TOKEN_CNT_INDEX = 7; +constexpr uint32_t SRC_OFFSET_RANK_TOKEN_IDX_INDEX = 8; +constexpr uint32_t DST_OFFSET_RANK_TOKEN_IDX_INDEX = 9; +constexpr uint32_t OUTPUT_EXPAND_X_INDEX = 0; +constexpr uint32_t 
OUTPUT_DYNAMIC_SCALES_INDEX = 1; +constexpr uint32_t OUTPUT_EXPAND_IDX_INDEX = 2; +constexpr uint32_t OUTPUT_EXPERT_TOKEN_NUMS_INDEX = 3; +constexpr uint32_t OUTPUT_EP_RECV_COUNTS_INDEX = 4; +constexpr uint32_t OUTPUT_TP_RECV_COUNTS_INDEX = 5; + +constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; +constexpr uint32_t ATTR_EP_WORLD_SIZE_INDEX = 1; +constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2; +constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3; +constexpr uint32_t ATTR_GROUP_TP_INDEX = 4; +constexpr uint32_t ATTR_TP_WORLD_SIZE_INDEX = 5; +constexpr uint32_t ATTR_TP_RANK_ID_INDEX = 6; +constexpr uint32_t ATTR_EXPERT_SHARD_TYPE_INDEX = 7; +constexpr uint32_t ATTR_SHARED_EXPERT_NUM_INDEX = 8; +constexpr uint32_t ATTR_SHARED_EXPERT_RANK_NUM_INDEX = 9; +constexpr uint32_t ATTR_QUANT_MODE_INDEX = 10; +constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 11; +constexpr uint32_t ATTR_EXPERT_TOKEN_NUMS_TYPE_INDEX = 12; + +constexpr uint32_t TWO_DIMS = 2; +constexpr uint32_t ONE_DIM = 1; +constexpr uint32_t DYN_SCALE_DIMS = 1; +constexpr uint32_t EXPAND_IDX_DIMS = 1; +constexpr uint32_t DYNAMIC_SCALE_DIM_NUM = 1; +constexpr uint64_t INIT_TILINGKEY = 1000; +constexpr uint32_t ARR_LENGTH = 128; +constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8; +constexpr uint32_t NO_SCALES = 0; +constexpr uint32_t STATIC_SCALES = 1; +constexpr uint32_t DYNAMIC_SCALES = 2; +constexpr uint32_t OP_TYPE_ALL_GATHER = 6; + +constexpr uint32_t UNQUANT_MODE = 0; +constexpr uint32_t STATIC_QUANT_MODE = 1; +constexpr uint32_t DYNAMIC_QUANT_MODE = 2; +constexpr uint32_t RANK_NUM_PER_NODE_A2 = 8; +constexpr uint32_t BLOCK_SIZE_A2 = 32; +constexpr uint32_t MAX_K_VALUE_A2 = 8; +constexpr int32_t MAX_HIDDEN_SIZE_A2 = 7168; +constexpr int32_t MAX_EP_WORLD_SIZE_A2 = 256; +constexpr int32_t MAX_MOE_EXPERT_NUMS_A2 = 512; +constexpr uint32_t SUPPORT_HIDDEN_SIZE = 7168; +const char *K_INNER_DEBUG = "CamHCommMoeDistributeDispatch Tiling Debug"; +const size_t MAX_GROUP_NAME_LENGTH = 128UL; +const int64_t MAX_EP_WORLD_SIZE = 288; +const int64_t MAX_TP_WORLD_SIZE = 2; +const int64_t BS_UPPER_BOUND = 512; + +constexpr uint32_t SHARED_EXPERT_NUM = 1; +constexpr uint64_t BUFF_NUM = 2; +constexpr uint64_t FLOAT16_SIZE = 2; +constexpr uint32_t EXPERT_TOKEN_NUM_TYPE_SUM = 0; +constexpr uint32_t EXPERT_TOKEN_NUM_TYPE_COUNT = 1; +constexpr uint32_t SCALES_TILING_KEY = 10; +constexpr uint32_t TP_TILING_KEY = 100; +constexpr uint32_t VERSION_2 = 2; +constexpr uint32_t HCOMMCNT_2 = 2; +constexpr int64_t MOE_EXPERT_MAX_NUM = 512; +constexpr int64_t K_MAX = 8; +constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; +constexpr uint32_t USER_WORKSPACE_A2 = 1 * 1024 * 1024; // moeExpertNum_ * sizeof(uint32_t) + epWorldSize_ * 2 * 32 +constexpr int32_t HCCL_BUFFER_SIZE_DEFAULT = 200 * 1024 * 1024; // Bytes +constexpr uint64_t MB_SIZE = 1024UL * 1024UL; + +constexpr uint64_t TILING_KEY_BASE_A2 = 2000000000; +constexpr uint64_t TILING_KEY_LAYERED_COMM_A2 = 100000000; +} // namespace + +namespace optiling { +static void PrintTilingDataInfo(const char *nodeName, const CamMoeDistributeDispatchTilingData &tilingData) +{ + OP_LOGD(nodeName, "epWorldSize is %u.", tilingData.moeDistributeDispatchInfo.epWorldSize); + OP_LOGD(nodeName, "tpWorldSize is %u.", tilingData.moeDistributeDispatchInfo.tpWorldSize); + OP_LOGD(nodeName, "epRankId is %u.", tilingData.moeDistributeDispatchInfo.epRankId); + OP_LOGD(nodeName, "tpRankId is %u.", tilingData.moeDistributeDispatchInfo.tpRankId); + OP_LOGD(nodeName, "expertShardType is %u.", tilingData.moeDistributeDispatchInfo.expertShardType); + 
OP_LOGD(nodeName, "sharedExpertRankNum is %u.", tilingData.moeDistributeDispatchInfo.sharedExpertRankNum); + OP_LOGD(nodeName, "moeExpertNum is %u.", tilingData.moeDistributeDispatchInfo.moeExpertNum); + OP_LOGD(nodeName, "quantMode is %u.", tilingData.moeDistributeDispatchInfo.quantMode); + OP_LOGD(nodeName, "globalBs is %u.", tilingData.moeDistributeDispatchInfo.globalBs); + OP_LOGD(nodeName, "isQuant is %d.", tilingData.moeDistributeDispatchInfo.isQuant); + OP_LOGD(nodeName, "bs is %u.", tilingData.moeDistributeDispatchInfo.bs); + OP_LOGD(nodeName, "k is %u.", tilingData.moeDistributeDispatchInfo.k); + OP_LOGD(nodeName, "h is %u.", tilingData.moeDistributeDispatchInfo.h); + OP_LOGD(nodeName, "aivNum is %u.", tilingData.moeDistributeDispatchInfo.aivNum); + OP_LOGD(nodeName, "totalUbSize is %lu.", tilingData.moeDistributeDispatchInfo.totalUbSize); + OP_LOGD(nodeName, "totalWinSize is %lu.", tilingData.moeDistributeDispatchInfo.totalWinSize); +} + +static bool CheckTensorDim(const gert::TilingContext &context, const char *nodeName, const bool isScales, + const uint32_t quantMode) +{ + const gert::StorageShape *xStorageShape = context.GetInputShape(X_INDEX); + OP_TILING_CHECK(xStorageShape == nullptr, OP_LOGE(nodeName, "xShape is null."), return false); + OP_TILING_CHECK(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OP_LOGE(nodeName, "xShape dims must be %u, but current dim num is %lu.", TWO_DIMS, + xStorageShape->GetStorageShape().GetDimNum()), + return false); + OP_LOGD(nodeName, "x dim0 = %ld", xStorageShape->GetStorageShape().GetDim(0)); + OP_LOGD(nodeName, "x dim1 = %ld", xStorageShape->GetStorageShape().GetDim(1)); + + const gert::StorageShape *expertIdStorageShape = context.GetInputShape(EXPERT_IDS_INDEX); + OP_TILING_CHECK(expertIdStorageShape == nullptr, OP_LOGE(nodeName, "expertIdShape is null."), return false); + OP_TILING_CHECK(expertIdStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OP_LOGE(nodeName, "expertIdShape dims must be %u, but current dim num is %lu.", TWO_DIMS, + expertIdStorageShape->GetStorageShape().GetDimNum()), + return false); + OP_LOGD(nodeName, "expertId dim0 = %ld", expertIdStorageShape->GetStorageShape().GetDim(0)); + OP_LOGD(nodeName, "expertId dim1 = %ld", expertIdStorageShape->GetStorageShape().GetDim(1)); + // 如果scales不为空进行shape维度检查 + if (isScales) { + const gert::StorageShape *scalesStorageShape = context.GetOptionalInputShape(SCALES_INDEX); + OP_TILING_CHECK(scalesStorageShape == nullptr, OP_LOGE(nodeName, "scalesShape is null."), return false); + OP_TILING_CHECK(scalesStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OP_LOGE(nodeName, "scalesShape dims must be %u, but current dim num is %lu.", TWO_DIMS, + scalesStorageShape->GetStorageShape().GetDimNum()), + return false); + OP_LOGD(nodeName, "scales dim0 = %ld", scalesStorageShape->GetStorageShape().GetDim(0)); + OP_LOGD(nodeName, "scales dim1 = %ld", scalesStorageShape->GetStorageShape().GetDim(1)); + } + + const gert::StorageShape *expandXStorageShape = context.GetOutputShape(OUTPUT_EXPAND_X_INDEX); + OP_TILING_CHECK(expandXStorageShape == nullptr, OP_LOGE(nodeName, "expandXShape is null."), return false); + OP_TILING_CHECK(expandXStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OP_LOGE(nodeName, "expandXShape dims must be %u, but current dim num is %lu.", TWO_DIMS, + expandXStorageShape->GetStorageShape().GetDimNum()), + return false); + OP_LOGD(nodeName, "expandX dim0 = %ld", expandXStorageShape->GetStorageShape().GetDim(0)); + OP_LOGD(nodeName, "expandX 
dim1 = %ld", expandXStorageShape->GetStorageShape().GetDim(1)); + + if (quantMode == DYNAMIC_SCALES) { + const gert::StorageShape *dynamicScalesStorageShape = context.GetOutputShape(OUTPUT_DYNAMIC_SCALES_INDEX); + OP_TILING_CHECK(dynamicScalesStorageShape == nullptr, OP_LOGE(nodeName, "dynamicScalesShape is null."), + return false); + OP_TILING_CHECK(dynamicScalesStorageShape->GetStorageShape().GetDimNum() != DYNAMIC_SCALE_DIM_NUM, + OP_LOGE(nodeName, "dynamicScalesShape dims must be %u, but current dim num is %lu.", + DYNAMIC_SCALE_DIM_NUM, dynamicScalesStorageShape->GetStorageShape().GetDimNum()), + return false); + OP_LOGD(nodeName, "dynamicScales dim0 = %ld", dynamicScalesStorageShape->GetStorageShape().GetDim(0)); + } + + const gert::StorageShape *expandIdxStorageShape = context.GetOutputShape(OUTPUT_EXPAND_IDX_INDEX); + OP_TILING_CHECK(expandIdxStorageShape == nullptr, OP_LOGE(nodeName, "expandIdxShape is null."), return false); + OP_TILING_CHECK(expandIdxStorageShape->GetStorageShape().GetDimNum() != ONE_DIM, + OP_LOGE(nodeName, "expandIdxShape dims must be %u, but current dim num is %lu.", ONE_DIM, + expandIdxStorageShape->GetStorageShape().GetDimNum()), + return false); + OP_LOGD(nodeName, "expandIdx dim0 = %ld", expandIdxStorageShape->GetStorageShape().GetDim(0)); + + const gert::StorageShape *expertTokenNumsStorageShape = context.GetOutputShape(OUTPUT_EXPERT_TOKEN_NUMS_INDEX); + OP_TILING_CHECK(expertTokenNumsStorageShape == nullptr, OP_LOGE(nodeName, "expertTokenNumsShape is null."), + return false); + OP_TILING_CHECK(expertTokenNumsStorageShape->GetStorageShape().GetDimNum() != ONE_DIM, + OP_LOGE(nodeName, "expertTokenNumsShape dims must be %u, but current dim num is %lu.", ONE_DIM, + expertTokenNumsStorageShape->GetStorageShape().GetDimNum()), + return false); + OP_LOGD(nodeName, "expertTokenNums dim0 = %ld", expertTokenNumsStorageShape->GetStorageShape().GetDim(0)); + + const gert::StorageShape *epRecvCountStorageShape = context.GetOutputShape(OUTPUT_EP_RECV_COUNTS_INDEX); + OP_TILING_CHECK(epRecvCountStorageShape == nullptr, OP_LOGE(nodeName, "epRecvCountShape is null."), return false); + OP_TILING_CHECK(epRecvCountStorageShape->GetStorageShape().GetDimNum() != ONE_DIM, + OP_LOGE(nodeName, "epRecvCountShape dims must be %u, but current dim num is %lu.", ONE_DIM, + epRecvCountStorageShape->GetStorageShape().GetDimNum()), + return false); + OP_LOGD(nodeName, "epRecvCount dim0 = %ld", epRecvCountStorageShape->GetStorageShape().GetDim(0)); + + // const gert::StorageShape *tpRecvCountStorageShape = context.GetOutputShape(OUTPUT_TP_RECV_COUNTS_INDEX); + // OP_TILING_CHECK(tpRecvCountStorageShape == nullptr, OP_LOGE(nodeName, "tpRecvCountShape is null."), return + // false); OP_TILING_CHECK(tpRecvCountStorageShape->GetStorageShape().GetDimNum() != ONE_DIM, + // OP_LOGE(nodeName, "tpRecvCountShape dims must be %u, but current dim num is %lu.", ONE_DIM, + // tpRecvCountStorageShape->GetStorageShape().GetDimNum()), + // return false); + // OP_LOGD(nodeName, "tpRecvCount dim0 = %ld", tpRecvCountStorageShape->GetStorageShape().GetDim(0)); + + return true; +} + +static bool CheckTensorDataType(const gert::TilingContext &context, const char *nodeName, const bool isScales, + const uint32_t quantMode) +{ + auto xDesc = context.GetInputDesc(X_INDEX); + OP_TILING_CHECK(xDesc == nullptr, OP_LOGE(nodeName, "xDesc is null."), return false); + OP_TILING_CHECK((xDesc->GetDataType() != ge::DT_BF16) && (xDesc->GetDataType() != ge::DT_FLOAT16), + OP_LOGE(nodeName, "x datatype is invalid, datatype 
should be bf16 or float16, but is %d.", + static_cast(xDesc->GetDataType())), + return false); + + auto expertIdDesc = context.GetInputDesc(EXPERT_IDS_INDEX); + OP_TILING_CHECK(expertIdDesc == nullptr, OP_LOGE(nodeName, "expertIdDesc is null."), return false); + OP_TILING_CHECK(expertIdDesc->GetDataType() != ge::DT_INT32, + OP_LOGE(nodeName, "expertId datatype is invalid, datatype should be int32, but is %d.", + static_cast(expertIdDesc->GetDataType())), + return false); + + if (isScales) { + auto scalesDesc = context.GetOptionalInputDesc(SCALES_INDEX); + OP_TILING_CHECK(scalesDesc == nullptr, OP_LOGE(nodeName, "scalesDesc is null."), return false); + OP_TILING_CHECK(scalesDesc->GetDataType() != ge::DT_FLOAT, + OP_LOGE(nodeName, "scales datatype is invalid, datatype should be float, but is %d.", + static_cast(scalesDesc->GetDataType())), + return false); + } + + auto expandXDesc = context.GetOutputDesc(OUTPUT_EXPAND_X_INDEX); + OP_TILING_CHECK(expandXDesc == nullptr, OP_LOGE(nodeName, "expandXDesc is null."), return false); + if (quantMode != NO_SCALES) { + OP_TILING_CHECK(expandXDesc->GetDataType() != ge::DT_INT8, + OP_LOGE(nodeName, "expandX datatype is invalid, datatype should be int8, but is %d.", + static_cast(expandXDesc->GetDataType())), + return false); + } else { + OP_TILING_CHECK( + expandXDesc->GetDataType() != xDesc->GetDataType(), + OP_LOGE(nodeName, "expandX dataType is invalid, dataType should be equal to x dataType %d, but is %d.", + static_cast(xDesc->GetDataType()), + static_cast(expandXDesc->GetDataType())), + return false); + } + + if (quantMode == DYNAMIC_SCALES) { + auto dynamicScalesDesc = context.GetOutputDesc(OUTPUT_DYNAMIC_SCALES_INDEX); + OP_TILING_CHECK(dynamicScalesDesc == nullptr, OP_LOGE(nodeName, "dynamicScalesDesc is null."), return false); + OP_TILING_CHECK(dynamicScalesDesc->GetDataType() != ge::DT_FLOAT, + OP_LOGE(nodeName, "dynamicScales datatype is invalid, datatype should be float, but is %d.", + static_cast(dynamicScalesDesc->GetDataType())), + return false); + } + + auto expandIdxDesc = context.GetOutputDesc(OUTPUT_EXPAND_IDX_INDEX); + OP_TILING_CHECK(expandIdxDesc == nullptr, OP_LOGE(nodeName, "expandIdxDesc is null."), return false); + OP_TILING_CHECK(expandIdxDesc->GetDataType() != ge::DT_INT32, + OP_LOGE(nodeName, "expandIdx datatype is invalid, datatype should be int32, but is %d.", + static_cast(expandIdxDesc->GetDataType())), + return false); + + auto expertTokenNumsDesc = context.GetOutputDesc(OUTPUT_EXPERT_TOKEN_NUMS_INDEX); + OP_TILING_CHECK(expertTokenNumsDesc == nullptr, OP_LOGE(nodeName, "expertTokenNumsDesc is null."), return false); + OP_TILING_CHECK(expertTokenNumsDesc->GetDataType() != ge::DT_INT64, + OP_LOGE(nodeName, "expertTokenNums datatype is invalid, datatype should be int64, but is %d.", + static_cast(expertTokenNumsDesc->GetDataType())), + return false); + + auto epRecvCountsDesc = context.GetOutputDesc(OUTPUT_EP_RECV_COUNTS_INDEX); + OP_TILING_CHECK(epRecvCountsDesc == nullptr, OP_LOGE(nodeName, "epRecvCountsDesc is null."), return false); + OP_TILING_CHECK(epRecvCountsDesc->GetDataType() != ge::DT_INT32, + OP_LOGE(nodeName, "epRecvCounts datatype is invalid, datatype should be int32, but is %d.", + static_cast(epRecvCountsDesc->GetDataType())), + return false); + + // auto tpRecvCountsDesc = context.GetOutputDesc(OUTPUT_TP_RECV_COUNTS_INDEX); + // OP_TILING_CHECK(tpRecvCountsDesc == nullptr, OP_LOGE(nodeName, "tpRecvCountsDesc is null."), return false); + // OP_TILING_CHECK(tpRecvCountsDesc->GetDataType() != ge::DT_INT32, 
+ // OP_LOGE(nodeName, "tpRecvCounts datatype is invalid, datatype should be int32, but is %d.", + // static_cast(tpRecvCountsDesc->GetDataType())), + // return false); + return true; +} + +static bool CheckTensorFormat(const gert::TilingContext &context, const char *nodeName, const bool isScales, + const uint32_t quantMode) +{ + auto xDesc = context.GetInputDesc(X_INDEX); + OP_TILING_CHECK(xDesc == nullptr, OP_LOGE(nodeName, "xDesc is null."), return false); + OP_TILING_CHECK(static_cast(ge::GetPrimaryFormat(xDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, + OP_LOGE(nodeName, "x format is invalid."), return false); + + auto expertIdDesc = context.GetInputDesc(EXPERT_IDS_INDEX); + OP_TILING_CHECK(expertIdDesc == nullptr, OP_LOGE(nodeName, "expertIdDesc is null."), return false); + OP_TILING_CHECK( + static_cast(ge::GetPrimaryFormat(expertIdDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, + OP_LOGE(nodeName, "expertId format is invalid."), return false); + + if (isScales) { + auto scalesDesc = context.GetOptionalInputDesc(SCALES_INDEX); + OP_TILING_CHECK(scalesDesc == nullptr, OP_LOGE(nodeName, "scalesDesc is null."), return false); + OP_TILING_CHECK( + static_cast(ge::GetPrimaryFormat(scalesDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, + OP_LOGE(nodeName, "scales format is invalid."), return false); + } + + auto expandXDesc = context.GetOutputDesc(OUTPUT_EXPAND_X_INDEX); + OP_TILING_CHECK(expandXDesc == nullptr, OP_LOGE(nodeName, "expandXDesc is null."), return false); + OP_TILING_CHECK( + static_cast(ge::GetPrimaryFormat(expandXDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, + OP_LOGE(nodeName, "expandX format is invalid."), return false); + + if (quantMode == DYNAMIC_SCALES) { + auto dynamicScalesDesc = context.GetOutputDesc(OUTPUT_DYNAMIC_SCALES_INDEX); + OP_TILING_CHECK(dynamicScalesDesc == nullptr, OP_LOGE(nodeName, "dynamicScalesDesc is null."), return false); + OP_TILING_CHECK(static_cast(ge::GetPrimaryFormat(dynamicScalesDesc->GetStorageFormat())) == + ge::FORMAT_FRACTAL_NZ, + OP_LOGE(nodeName, "dynamicScales format is invalid."), return false); + } + + auto expandIdxDesc = context.GetOutputDesc(OUTPUT_EXPAND_IDX_INDEX); + OP_TILING_CHECK(expandIdxDesc == nullptr, OP_LOGE(nodeName, "expandIdxDesc is null."), return false); + OP_TILING_CHECK( + static_cast(ge::GetPrimaryFormat(expandIdxDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, + OP_LOGE(nodeName, "expandIdx format is invalid."), return false); + + auto expertTokenNumsDesc = context.GetOutputDesc(OUTPUT_EXPERT_TOKEN_NUMS_INDEX); + OP_TILING_CHECK(expertTokenNumsDesc == nullptr, OP_LOGE(nodeName, "expertTokenNumsDesc is null."), return false); + OP_TILING_CHECK( + static_cast(ge::GetPrimaryFormat(expertTokenNumsDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, + OP_LOGE(nodeName, "expertTokenNums format is invalid."), return false); + + auto epRecvCountsDesc = context.GetOutputDesc(OUTPUT_EP_RECV_COUNTS_INDEX); + OP_TILING_CHECK(epRecvCountsDesc == nullptr, OP_LOGE(nodeName, "epRecvCountsDesc is null."), return false); + OP_TILING_CHECK( + static_cast(ge::GetPrimaryFormat(epRecvCountsDesc->GetStorageFormat())) == ge::FORMAT_FRACTAL_NZ, + OP_LOGE(nodeName, "epRecvCounts format is invalid."), return false); + + // auto tpRecvCountsDesc = context.GetOutputDesc(OUTPUT_TP_RECV_COUNTS_INDEX); + // OP_TILING_CHECK(tpRecvCountsDesc == nullptr, OP_LOGE(nodeName, "tpRecvCountsDesc is null."), return false); + // OP_TILING_CHECK( + // static_cast(ge::GetPrimaryFormat(tpRecvCountsDesc->GetStorageFormat())) == 
ge::FORMAT_FRACTAL_NZ, + // OP_LOGE(nodeName, "tpRecvCounts format is invalid."), return false); + return true; +} + +static ge::graphStatus GetAttrAndSetTilingData(const gert::TilingContext &context, const char *nodeName, + CamMoeDistributeDispatchTilingData &tilingData, std::string &groupEp, + std::string &groupTp) +{ + auto attrs = context.GetAttrs(); + OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); + + auto groupEpPtr = attrs->GetAttrPointer(static_cast(ATTR_GROUP_EP_INDEX)); + auto groupTpPtr = attrs->GetAttrPointer(static_cast(ATTR_GROUP_TP_INDEX)); + auto epWorldSizePtr = attrs->GetAttrPointer(ATTR_EP_WORLD_SIZE_INDEX); + auto tpWorldSizePtr = attrs->GetAttrPointer(ATTR_TP_WORLD_SIZE_INDEX); + auto epRankIdPtr = attrs->GetAttrPointer(ATTR_EP_RANK_ID_INDEX); + auto tpRankIdPtr = attrs->GetAttrPointer(ATTR_TP_RANK_ID_INDEX); + auto expertShardPtr = attrs->GetAttrPointer(ATTR_EXPERT_SHARD_TYPE_INDEX); + auto sharedExpertRankNumPtr = attrs->GetAttrPointer(ATTR_SHARED_EXPERT_RANK_NUM_INDEX); + auto moeExpertNumPtr = attrs->GetAttrPointer(ATTR_MOE_EXPERT_NUM_INDEX); + auto quantModePtr = attrs->GetAttrPointer(ATTR_QUANT_MODE_INDEX); + auto sharedExpertNumPtr = attrs->GetAttrPointer(static_cast(ATTR_SHARED_EXPERT_NUM_INDEX)); + auto expertTokenNumsTypePtr = attrs->GetAttrPointer(static_cast(ATTR_EXPERT_TOKEN_NUMS_TYPE_INDEX)); + // 判空 + OP_TILING_CHECK((groupEpPtr == nullptr) || (strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == 0) || + (strnlen(groupEpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH), + OP_LOGE(nodeName, "groupEpPtr is null or invalid."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(epWorldSizePtr == nullptr, OP_LOGE(nodeName, "epWorldSizePtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(tpWorldSizePtr == nullptr, OP_LOGE(nodeName, "tpWorldSizePtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(epRankIdPtr == nullptr, OP_LOGE(nodeName, "epRankIdPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(tpRankIdPtr == nullptr, OP_LOGE(nodeName, "tpRankIdPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(expertShardPtr == nullptr, OP_LOGE(nodeName, "expertShardPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(sharedExpertRankNumPtr == nullptr, OP_LOGE(nodeName, "sharedExpertRankNumPtr is null."), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(moeExpertNumPtr == nullptr, OP_LOGE(nodeName, "moeExpertNumPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(quantModePtr == nullptr, OP_LOGE(nodeName, "quantModePtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(sharedExpertNumPtr == nullptr, OP_LOGE(nodeName, "sharedExpertNum is null."), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(expertTokenNumsTypePtr == nullptr, OP_LOGE(nodeName, "expertTokenNumsType is null."), + return ge::GRAPH_FAILED); + // 判断是否满足uint32_t及其他限制 + OP_TILING_CHECK((*epWorldSizePtr <= 0) || (*epWorldSizePtr > MAX_EP_WORLD_SIZE), + OP_LOGE(nodeName, "epWorldSize is invalid, only support (0, %ld], but got epWorldSize=%ld.", + MAX_EP_WORLD_SIZE, *epWorldSizePtr), + return ge::GRAPH_FAILED); + OP_TILING_CHECK((*tpWorldSizePtr < 0) || (*tpWorldSizePtr > MAX_TP_WORLD_SIZE), + OP_LOGE(nodeName, "tpWorldSize is invalid, only support [0, %ld], but got tpWorldSize=%ld.", + MAX_TP_WORLD_SIZE, *tpWorldSizePtr), + return ge::GRAPH_FAILED); + OP_TILING_CHECK((*epRankIdPtr < 0) || (*epRankIdPtr >= *epWorldSizePtr), + OP_LOGE(nodeName, "epRankId is invalid, only support [0, %ld), but got epRankId=%ld.", + *epWorldSizePtr, 
*epRankIdPtr), + return ge::GRAPH_FAILED); + if (*tpWorldSizePtr > 1) { + OP_TILING_CHECK((*tpRankIdPtr < 0) || (*tpRankIdPtr >= *tpWorldSizePtr), + OP_LOGE(nodeName, "tpRankId is invalid, only support [0, %ld), but got tpRankId=%ld.", + *tpWorldSizePtr, *tpRankIdPtr), + return ge::GRAPH_FAILED); + OP_TILING_CHECK((groupTpPtr == nullptr) || (strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == 0) || + (strnlen(groupTpPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH), + OP_LOGE(nodeName, "groupTpPtr is null or invalid.."), return ge::GRAPH_FAILED); + groupTp = std::string(groupTpPtr); + } else { + OP_TILING_CHECK( + *tpRankIdPtr != 0, + OP_LOGE(nodeName, "tpRankId is invalid, NoTp mode only support 0, but got tpRankId=%ld.", *tpRankIdPtr), + return ge::GRAPH_FAILED); + } + OP_TILING_CHECK( + *expertShardPtr != 0, + OP_LOGE(nodeName, "expertShardType is invalid, only support 0, but got expertShardType=%ld.", *expertShardPtr), + return ge::GRAPH_FAILED); + OP_TILING_CHECK( + (*sharedExpertRankNumPtr < 0) || (*sharedExpertRankNumPtr >= *epWorldSizePtr), + OP_LOGE(nodeName, "sharedExpertRankNum is invalid, only support [0, %ld), but got sharedExpertRankNum=%ld.", + *epWorldSizePtr, *sharedExpertRankNumPtr), + return ge::GRAPH_FAILED); + OP_TILING_CHECK((*moeExpertNumPtr <= 0) || (*moeExpertNumPtr > MOE_EXPERT_MAX_NUM), + OP_LOGE(nodeName, "moeExpertNum is invalid, only support (0, %ld], but got moeExpertNum=%ld.", + MOE_EXPERT_MAX_NUM, *moeExpertNumPtr), + return ge::GRAPH_FAILED); + OP_TILING_CHECK( + (*quantModePtr < static_cast(NO_SCALES)) || (*quantModePtr > static_cast(DYNAMIC_SCALES)), + OP_LOGE(nodeName, "quantMode is invalid, only support [0, %u], but got quantMode=%ld.", DYNAMIC_SCALES, + *quantModePtr), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(*sharedExpertNumPtr != SHARED_EXPERT_NUM, + OP_LOGE(nodeName, "sharedExpertNum only support %u, but got sharedExpertNum=%ld.", + SHARED_EXPERT_NUM, *sharedExpertNumPtr), + return ge::GRAPH_FAILED); + OP_TILING_CHECK((*expertTokenNumsTypePtr != EXPERT_TOKEN_NUM_TYPE_SUM) && + (*expertTokenNumsTypePtr != EXPERT_TOKEN_NUM_TYPE_COUNT), + OP_LOGE(nodeName, "expertTokenNumsType only support 0 or 1, but got expertTokenNumsType=%ld.", + *expertTokenNumsTypePtr), + return ge::GRAPH_FAILED); + + groupEp = std::string(groupEpPtr); + tilingData.moeDistributeDispatchInfo.epWorldSize = static_cast(*epWorldSizePtr); + tilingData.moeDistributeDispatchInfo.tpWorldSize = static_cast(*tpWorldSizePtr); + tilingData.moeDistributeDispatchInfo.epRankId = static_cast(*epRankIdPtr); + tilingData.moeDistributeDispatchInfo.tpRankId = static_cast(*tpRankIdPtr); + tilingData.moeDistributeDispatchInfo.expertShardType = static_cast(*expertShardPtr); + tilingData.moeDistributeDispatchInfo.sharedExpertRankNum = static_cast(*sharedExpertRankNumPtr); + tilingData.moeDistributeDispatchInfo.moeExpertNum = static_cast(*moeExpertNumPtr); + tilingData.moeDistributeDispatchInfo.quantMode = static_cast(*quantModePtr); + tilingData.moeDistributeDispatchInfo.expertTokenNumsType = static_cast(*expertTokenNumsTypePtr); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckAttrs(const gert::TilingContext &context, const char *nodeName, + CamMoeDistributeDispatchTilingData &tilingData, uint32_t &localMoeExpertNum) +{ + uint32_t epWorldSize = tilingData.moeDistributeDispatchInfo.epWorldSize; + uint32_t tpWorldSize = tilingData.moeDistributeDispatchInfo.tpWorldSize; + uint32_t moeExpertNum = tilingData.moeDistributeDispatchInfo.moeExpertNum; + uint32_t sharedExpertRankNum = 
tilingData.moeDistributeDispatchInfo.sharedExpertRankNum; + // 校验ep能否均分共享专家 + OP_TILING_CHECK((sharedExpertRankNum != 0) && (epWorldSize % sharedExpertRankNum != 0), + OP_LOGE(nodeName, + "epWorldSize should be divisible by sharedExpertRankNum, but epWorldSize=%u, " + "sharedExpertRankNum=%u.", + epWorldSize, sharedExpertRankNum), + return ge::GRAPH_FAILED); + // 校验moe专家数量能否均分给多机 + localMoeExpertNum = moeExpertNum / (epWorldSize - sharedExpertRankNum); + OP_TILING_CHECK(moeExpertNum % (epWorldSize - sharedExpertRankNum) != 0, + OP_LOGE(nodeName, + "moeExpertNum should be divisible by (epWorldSize - sharedExpertRankNum), " + "but moeExpertNum=%u, epWorldSize=%u, sharedExpertRankNum=%u.", + moeExpertNum, epWorldSize, sharedExpertRankNum), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(localMoeExpertNum <= 0, + OP_LOGE(nodeName, "localMoeExpertNum is invalid, localMoeExpertNum = %u", localMoeExpertNum), + return ge::GRAPH_FAILED); + OP_TILING_CHECK((tpWorldSize > 1) && (localMoeExpertNum > 1), + OP_LOGE(nodeName, + "Cannot support multi-moeExpert %u " + "in a rank when tpWorldSize = %u > 1", + localMoeExpertNum, tpWorldSize), + return ge::GRAPH_FAILED); + // 检验epWorldSize是否是8的倍数 + OP_TILING_CHECK(epWorldSize % 8 != 0, + OP_LOGE(nodeName, "epWorldSize should be divisible by 8, but got epWorldSize = %u.", epWorldSize), + return ge::GRAPH_FAILED); + + OP_TILING_CHECK( + (256 % epWorldSize != 0) && (epWorldSize % 144 != 0), + OP_LOGE(nodeName, + "epWorldSize should be in the list[8, 16, 32, 64, 128, 144, 256, 288], but got epWorldSize = %u.", + epWorldSize), + return ge::GRAPH_FAILED); + // 校验输入x的dim 0并设bs + const gert::StorageShape *xStorageShape = context.GetInputShape(X_INDEX); + OP_TILING_CHECK(xStorageShape == nullptr, OP_LOGE(nodeName, "xStorageShape is nullptr."), return ge::GRAPH_FAILED); + const int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0); + OP_TILING_CHECK((xDim0 > BS_UPPER_BOUND) || (xDim0 <= 0), + OP_LOGE(nodeName, "xDim0(BS) is invalid. 
Should be between [1, %ld], but got xDim0=%ld.", + BS_UPPER_BOUND, xDim0), + return ge::GRAPH_FAILED); + tilingData.moeDistributeDispatchInfo.bs = static_cast(xDim0); + // 校验globalBS + auto attrs = context.GetAttrs(); + OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); + auto globalBsPtr = attrs->GetAttrPointer(ATTR_GLOBAL_BS_INDEX); + OP_TILING_CHECK(globalBsPtr == nullptr, OP_LOGE(nodeName, "globalBsPtr is nullptr."), return ge::GRAPH_FAILED); + OP_LOGD(nodeName, "CamHCommMoeDistributeDispatch *globalBsPtr = %ld, bs = %ld, epWorldSize = %u\n", *globalBsPtr, + xDim0, epWorldSize); + OP_TILING_CHECK( + (*globalBsPtr != 0) && ((*globalBsPtr < xDim0 * static_cast(epWorldSize)) || + ((*globalBsPtr) % (static_cast(epWorldSize)) != 0)), + OP_LOGE(nodeName, + "globalBS is invalid, only " + "support 0 or maxBs(maxBs is the largest bs on all ranks) * epWorldSize, but got globalBS=%ld, " + "bs=%ld, epWorldSize=%u.", + *globalBsPtr, xDim0, epWorldSize), + return ge::GRAPH_FAILED); + if (*globalBsPtr == 0) { + tilingData.moeDistributeDispatchInfo.globalBs = static_cast(xDim0) * epWorldSize; + } else { + tilingData.moeDistributeDispatchInfo.globalBs = static_cast(*globalBsPtr); + } + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckTensorShape(const gert::TilingContext &context, const char *nodeName, + CamMoeDistributeDispatchTilingData &tilingData, const uint32_t quantMode, + const bool isScales, const bool isSharedExpert, const int64_t localMoeExpertNum) +{ + uint32_t A = 0; + uint32_t globalBs = tilingData.moeDistributeDispatchInfo.globalBs; + uint32_t sharedExpertRankNum = tilingData.moeDistributeDispatchInfo.sharedExpertRankNum; + // 校验输入x的维度1并设h, bs已校验过 + const gert::StorageShape *xStorageShape = context.GetInputShape(X_INDEX); + OP_TILING_CHECK(xStorageShape == nullptr, OP_LOGE(nodeName, "xStorageShape is nullptr."), return ge::GRAPH_FAILED); + const int64_t xDim0 = xStorageShape->GetStorageShape().GetDim(0); + const int64_t xDim1 = xStorageShape->GetStorageShape().GetDim(1); + OP_TILING_CHECK((xDim1 != SUPPORT_HIDDEN_SIZE), + OP_LOGE(nodeName, "xShape dims1(H) only supports %u, but got %ld.", SUPPORT_HIDDEN_SIZE, xDim1), + return ge::GRAPH_FAILED); + tilingData.moeDistributeDispatchInfo.h = static_cast(xDim1); + // 校验expert_id的维度并设k + int64_t moeExpertNum = static_cast(tilingData.moeDistributeDispatchInfo.moeExpertNum); + const gert::StorageShape *expertIdStorageShape = context.GetInputShape(EXPERT_IDS_INDEX); + OP_TILING_CHECK(expertIdStorageShape == nullptr, OP_LOGE(nodeName, "expertIdStorageShape is nullptr."), + return ge::GRAPH_FAILED); + const int64_t expertIdsDim0 = expertIdStorageShape->GetStorageShape().GetDim(0); + const int64_t expertIdsDim1 = expertIdStorageShape->GetStorageShape().GetDim(1); + OP_TILING_CHECK(xDim0 != expertIdsDim0, + OP_LOGE(nodeName, + "xShape's dim0 not equal to expertIdShape's dim0, " + "xShape's dim0 is %ld, expertIdShape's dim0 is %ld.", + xDim0, expertIdsDim0), + return ge::GRAPH_FAILED); + OP_TILING_CHECK( + (expertIdsDim1 <= 0) || (expertIdsDim1 > K_MAX), + OP_LOGE(nodeName, "expertIdShape's dim1(k) should be in (0, %ld], but got expertIdShape's dim1=%ld.", K_MAX, + expertIdsDim1), + return ge::GRAPH_FAILED); + tilingData.moeDistributeDispatchInfo.k = static_cast(expertIdsDim1); + // 校验scales的维度 + if (isScales) { + const gert::StorageShape *scalesStorageShape = context.GetOptionalInputShape(SCALES_INDEX); + OP_TILING_CHECK(scalesStorageShape == nullptr, OP_LOGE(nodeName, "scalesStorageShape is 
nullptr."), + return ge::GRAPH_FAILED); + const int64_t scalesDim0 = scalesStorageShape->GetStorageShape().GetDim(0); + const int64_t scalesDim1 = scalesStorageShape->GetStorageShape().GetDim(1); + if (sharedExpertRankNum == 0U) { + OP_TILING_CHECK( + scalesDim0 != moeExpertNum, + OP_LOGE(nodeName, "scales's dim0 not equal to moeExpertNum, scales's dim0 is %ld, moeExpertNum is %ld.", + scalesDim0, moeExpertNum), + return ge::GRAPH_FAILED); + } else { + OP_TILING_CHECK( + scalesDim0 != (moeExpertNum + 1UL), + OP_LOGE(nodeName, + "scales's dim0 not equal to moeExpertNum + 1, scales's dim0 is %ld, moeExpertNum + 1 is %ld.", + scalesDim0, moeExpertNum + 1UL), + return ge::GRAPH_FAILED); + } + OP_TILING_CHECK(xDim1 != scalesDim1, + OP_LOGE(nodeName, + "scales's dim1 not equal to xShape's dim1, " + "xShape's dim1 is %ld, scales's dim1 is %ld.", + xDim1, scalesDim1), + return ge::GRAPH_FAILED); + } + + if (isSharedExpert && sharedExpertRankNum != 0) { // 本卡为共享专家 + A = globalBs / sharedExpertRankNum; + } else { // 本卡为moe专家 + A = globalBs * std::min(localMoeExpertNum, expertIdsDim1); + } + // 校验expandX的维度 + int64_t tpWorldSize = static_cast(tilingData.moeDistributeDispatchInfo.tpWorldSize); + const gert::StorageShape *expandXStorageShape = context.GetOutputShape(OUTPUT_EXPAND_X_INDEX); + OP_TILING_CHECK(expandXStorageShape == nullptr, OP_LOGE(nodeName, "expandXStorageShape is nullptr."), + return ge::GRAPH_FAILED); + const int64_t expandXDim0 = expandXStorageShape->GetStorageShape().GetDim(0); + const int64_t expandXDim1 = expandXStorageShape->GetStorageShape().GetDim(1); + OP_TILING_CHECK(expandXDim0 < tpWorldSize * static_cast(A), + OP_LOGE(nodeName, + "expandX's dim0 not greater than or equal to A*tpWorldSize, " + "expandX's dim0 is %ld, A*tpWorldSize is %ld.", + expandXDim0, tpWorldSize * A), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(xDim1 != expandXDim1, + OP_LOGE(nodeName, + "expandX's dim1 not equal to xShape's dim1, " + "xShape's dim1 is %ld, expandX's dim1 is %ld.", + xDim1, expandXDim1), + return ge::GRAPH_FAILED); + // 校验dynamicScales的维度 + if (quantMode != NO_SCALES) { + const gert::StorageShape *dynamicScalesStorageShape = context.GetOutputShape(OUTPUT_DYNAMIC_SCALES_INDEX); + OP_TILING_CHECK(dynamicScalesStorageShape == nullptr, + OP_LOGE(nodeName, "dynamicScalesStorageShape is nullptr."), return ge::GRAPH_FAILED); + const int64_t dynamicScalesDim0 = dynamicScalesStorageShape->GetStorageShape().GetDim(0); + OP_TILING_CHECK(dynamicScalesDim0 < static_cast(A) * tpWorldSize, + OP_LOGE(nodeName, + "dynamicScales's dim0 should be equal to or greater than A*tpWorldSize, " + "dynamicScales's dim0 is %ld, A*tpWorldSize is %ld.", + dynamicScalesDim0, A * tpWorldSize), + return ge::GRAPH_FAILED); + } + // 校验expandIdx的维度 + const gert::StorageShape *expandIdxStorageShape = context.GetOutputShape(OUTPUT_EXPAND_IDX_INDEX); + OP_TILING_CHECK(expandIdxStorageShape == nullptr, OP_LOGE(nodeName, "expandIdxStorageShape is nullptr."), + return ge::GRAPH_FAILED); + const int64_t expandIdxDim0 = expandIdxStorageShape->GetStorageShape().GetDim(0); + OP_TILING_CHECK(expandIdxDim0 != expertIdsDim1 * xDim0, + OP_LOGE(nodeName, "expandIdxDim0 != bs * k, expandIdxDim0 is %ld, bs * k is %ld.", expandIdxDim0, + xDim0 * expertIdsDim1), + return ge::GRAPH_FAILED); + // 校验expertTokenNums的维度 + const gert::StorageShape *expertTokenNumsStorageShape = context.GetOutputShape(OUTPUT_EXPERT_TOKEN_NUMS_INDEX); + OP_TILING_CHECK(expertTokenNumsStorageShape == nullptr, + OP_LOGE(nodeName, "expertTokenNumsStorageShape is 
nullptr."), return ge::GRAPH_FAILED); + const int64_t expertTokenNumsDim0 = expertTokenNumsStorageShape->GetStorageShape().GetDim(0); + if (isSharedExpert) { + OP_TILING_CHECK(expertTokenNumsDim0 != ONE_DIM, + OP_LOGE(nodeName, "shared expertTokenNums's dim0 %ld not equal to 1.", expertTokenNumsDim0), + return ge::GRAPH_FAILED); + } else { + OP_TILING_CHECK( + expertTokenNumsDim0 != localMoeExpertNum, + OP_LOGE(nodeName, + "moe expertTokenNums's Dim0 not equal to localMoeExpertNum, expertTokenNumsDim0 is %ld, " + "localMoeExpertNum is %ld.", + expertTokenNumsDim0, localMoeExpertNum), + return ge::GRAPH_FAILED); + } + // 校验epRecvCount和tpRecvCount的维度 + int64_t epWorldSize = static_cast(tilingData.moeDistributeDispatchInfo.epWorldSize); + const gert::StorageShape *epRecvCountStorageShape = context.GetOutputShape(OUTPUT_EP_RECV_COUNTS_INDEX); + // const gert::StorageShape *tpRecvCountStorageShape = context.GetOutputShape(OUTPUT_TP_RECV_COUNTS_INDEX); + OP_TILING_CHECK(epRecvCountStorageShape == nullptr, OP_LOGE(nodeName, "epRecvCountStorageShape is nullptr."), + return ge::GRAPH_FAILED); + // OP_TILING_CHECK(tpRecvCountStorageShape == nullptr, OP_LOGE(nodeName, "tpRecvCountStorageShape is nullptr."), + // return ge::GRAPH_FAILED); + const int64_t epRecvCountDim0 = epRecvCountStorageShape->GetStorageShape().GetDim(0); + // const int64_t tpRecvCountDim0 = tpRecvCountStorageShape->GetStorageShape().GetDim(0); + int64_t epRecvCount = (isSharedExpert) ? epWorldSize : epWorldSize * localMoeExpertNum; + if (tpWorldSize == MAX_TP_WORLD_SIZE) { + epRecvCount *= tpWorldSize; + } + OP_TILING_CHECK( + epRecvCountDim0 < epRecvCount, + OP_LOGE( + nodeName, + "dimension 0 of epRecvCount should be greater than or equal to epWorldSize * localMoeExpertNum * " + "tpWorldSize, " + "but dimension 0 of epRecvCount is %ld, epWorldSize is %ld, localMoeExpertNum is %ld, tpWorldSize is %ld.", + epRecvCountDim0, epWorldSize, localMoeExpertNum, tpWorldSize), + return ge::GRAPH_FAILED); + // OP_TILING_CHECK( + // tpRecvCountDim0 != tpWorldSize, + // OP_LOGE(nodeName, + // "dimension 0 of tpRecvCount should be equal to tpWorldSize, but dimension 0 of tpRecvCount is %ld, " + // "tpWorldSize is %ld.", + // tpRecvCountDim0, tpWorldSize), + // return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus TilingCheckMoeDistributeDispatch(gert::TilingContext &context, const char *nodeName, + const bool isScales, const uint32_t quantMode) +{ + OP_TILING_CHECK(!CheckTensorDim(context, nodeName, isScales, quantMode), + OP_LOGE(nodeName, "params shape is invalid."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(!CheckTensorDataType(context, nodeName, isScales, quantMode), + OP_LOGE(nodeName, "params dataType is invalid."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(!CheckTensorFormat(context, nodeName, isScales, quantMode), + OP_LOGE(nodeName, "params format is invalid."), return ge::GRAPH_FAILED); + return ge::GRAPH_SUCCESS; +} + +static void CalTilingKey(uint64_t &tilingKey, const bool isScales, const uint32_t quantMode, const uint32_t tpWorldSize) +{ + tilingKey += static_cast(quantMode); + tilingKey += static_cast((isScales ? 
SCALES_TILING_KEY : 0)); + if (tpWorldSize == MAX_TP_WORLD_SIZE) { + tilingKey += static_cast(TP_TILING_KEY); + } + return; +} + +static void SetHcommCfg(const gert::TilingContext &context, CamMoeDistributeDispatchTilingData &tiling, + const std::string groupEp, const std::string groupTp) +{ + const char *nodeName = context.GetNodeName(); + OP_LOGD(nodeName, "CamHCommMoeDistributeDispatch groupEp = %s, groupTp = %s", groupEp.c_str(), groupTp.c_str()); + uint32_t opType1 = OP_TYPE_ALL_TO_ALL; + uint32_t opType2 = OP_TYPE_ALL_GATHER; + std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise"; + std::string algConfigAllGatherStr = "AllGather=level0:ring"; + + AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupEp, opType1, algConfigAllToAllStr); + mc2CcTilingConfig.GetTiling(tiling.mc2InitTiling); + mc2CcTilingConfig.GetTiling(tiling.mc2CcTiling1); + + mc2CcTilingConfig.SetGroupName(groupTp); + mc2CcTilingConfig.SetOpType(opType2); + mc2CcTilingConfig.SetAlgConfig(algConfigAllGatherStr); + mc2CcTilingConfig.GetTiling(tiling.mc2CcTiling2); +} + +static ge::graphStatus SetWorkSpace(gert::TilingContext &context, const char *nodeName) +{ + size_t *workSpaces = context.GetWorkspaceSizes(1); + OP_TILING_CHECK(workSpaces == nullptr, OP_LOGE(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED); + workSpaces[0] = SYSTEM_NEED_WORKSPACE; + return ge::GRAPH_SUCCESS; +} + +static bool CheckIsA2(const gert::TilingContext &context) +{ + const char *nodeName = context.GetNodeName(); + fe::PlatFormInfos *platformInfoPtr = context.GetPlatformInfo(); + OP_TILING_CHECK(platformInfoPtr == nullptr, OP_LOGE(nodeName, "platformInfoPtr is nullptr."), return 0); + fe::PlatFormInfos &platformInfo = *platformInfoPtr; + + std::string socVersion; + (void)platformInfo.GetPlatformResWithLock("version", "Short_SoC_version", socVersion); + if (socVersion == "Ascend910B") { + return true; + } + return false; +} + +static ge::graphStatus MoeDistributeDispatchA3TilingFuncImpl(gert::TilingContext &context) +{ + const char *nodeName = context.GetNodeName(); + CamMoeDistributeDispatchTilingData *tilingData = context.GetTilingData(); + OP_TILING_CHECK(tilingData == nullptr, OP_LOGE(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED); + std::string groupEp = ""; + std::string groupTp = ""; + uint32_t quantMode = NO_SCALES; + bool isScales = false; + uint32_t localMoeExpertNum = 1; + OP_LOGI(nodeName, "Enter CamHCommMoeDistributeDispatch tiling check func."); + // 获取入参属性 + OP_TILING_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp, groupTp) != ge::GRAPH_SUCCESS, + OP_LOGE(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED); + // 获取scales + const gert::StorageShape *scalesStorageShape = context.GetOptionalInputShape(SCALES_INDEX); + isScales = (scalesStorageShape != nullptr); + tilingData->moeDistributeDispatchInfo.isQuant = isScales; + quantMode = tilingData->moeDistributeDispatchInfo.quantMode; + // 检查quantMode和scales是否匹配 + OP_TILING_CHECK(quantMode == STATIC_SCALES, OP_LOGE(nodeName, "cannot support static quant now."), + return ge::GRAPH_FAILED); + OP_TILING_CHECK((isScales && (quantMode == NO_SCALES)) || ((!isScales) && (quantMode == STATIC_SCALES)), + OP_LOGE(nodeName, "quant mode and scales not match, isScales is %d, quantMode is %u.", + static_cast(isScales), quantMode), + return ge::GRAPH_FAILED); + // 检查输入输出的dim、format、dataType + OP_TILING_CHECK(TilingCheckMoeDistributeDispatch(context, nodeName, isScales, quantMode) != ge::GRAPH_SUCCESS, + 
OP_LOGE(nodeName, "Tiling check param failed."), return ge::GRAPH_FAILED); + // 检查属性的取值是否合法 + OP_TILING_CHECK(CheckAttrs(context, nodeName, *tilingData, localMoeExpertNum) != ge::GRAPH_SUCCESS, + OP_LOGE(nodeName, "Check attr failed."), return ge::GRAPH_FAILED); + + bool isSharedExpert = true; + uint32_t sharedExpertRankNum = tilingData->moeDistributeDispatchInfo.sharedExpertRankNum; + + uint32_t epRankId = tilingData->moeDistributeDispatchInfo.epRankId; + if (epRankId >= sharedExpertRankNum) { // 本卡为moe专家 + isSharedExpert = false; + } + // 检查shape各维度并赋值h,k + OP_TILING_CHECK(CheckTensorShape(context, nodeName, *tilingData, quantMode, isScales, isSharedExpert, + static_cast(localMoeExpertNum)) != ge::GRAPH_SUCCESS, + OP_LOGE(nodeName, "Check tensor shape failed."), return ge::GRAPH_FAILED); + // 校验win区大小 + uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize(); + uint64_t bs = static_cast(tilingData->moeDistributeDispatchInfo.bs); + uint64_t h = static_cast(tilingData->moeDistributeDispatchInfo.h); + uint64_t epWorldSize = static_cast(tilingData->moeDistributeDispatchInfo.epWorldSize); + uint64_t maxBs = static_cast(tilingData->moeDistributeDispatchInfo.globalBs) / epWorldSize; + uint64_t actualSize = epWorldSize * maxBs * h * FLOAT16_SIZE * BUFF_NUM * static_cast(localMoeExpertNum); + if (actualSize > maxWindowSize) { + OP_LOGE(nodeName, + "HCCL_BUFFSIZE is too SMALL, maxBs = %lu, h = %lu, epWorldSize = %lu, localMoeExpertNum = %u," + "ep_worldsize * maxBs * h * %lu * %lu * localMoeExpertNum = %luMB, HCCL_BUFFSIZE=%luMB.", + maxBs, h, epWorldSize, localMoeExpertNum, FLOAT16_SIZE, BUFF_NUM, actualSize / MB_SIZE + 1UL, + maxWindowSize / MB_SIZE); + return ge::GRAPH_FAILED; + } + tilingData->moeDistributeDispatchInfo.totalWinSize = maxWindowSize; + + OP_TILING_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS, + OP_LOGE(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED); + SetHcommCfg(context, *tilingData, groupEp, groupTp); + uint32_t tpWorldSize = tilingData->moeDistributeDispatchInfo.tpWorldSize; + uint64_t tilingKey = INIT_TILINGKEY; + CalTilingKey(tilingKey, isScales, quantMode, tpWorldSize); + OP_LOGD(nodeName, "tilingKey is %lu", tilingKey); + context.SetTilingKey(tilingKey); + uint32_t blockDim = 1U; + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context.GetPlatformInfo()); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint64_t ubSize = 0UL; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum); + context.SetBlockDim(blockDim); + tilingData->moeDistributeDispatchInfo.totalUbSize = ubSize; + tilingData->moeDistributeDispatchInfo.aivNum = aivNum; + OP_LOGD(nodeName, "blockDim=%u, aivNum=%u, ubSize=%lu", blockDim, aivNum, ubSize); + PrintTilingDataInfo(nodeName, *tilingData); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus MoeDistributeDispatchA2CheckShapeAndSetTiling(const gert::TilingContext &context, + CamMoeDistributeDispatchA2Info &info) +{ + const char *nodeName = context.GetNodeName(); + OP_LOGI(nodeName, "MoeDistributeDispatchA2 MoeDistributeDispatchA2CheckShapeAndSetTiling."); + const gert::StorageShape *xStorageShape = context.GetInputShape(X_INDEX); + const gert::StorageShape *expertIdStorageShape = context.GetInputShape(EXPERT_IDS_INDEX); + const gert::StorageShape *scalesStorageShape = context.GetOptionalInputShape(SCALES_INDEX); + + OP_TILING_CHECK(xStorageShape == nullptr, OP_LOGE(K_INNER_DEBUG, "xShape is null."), return 
GRAPH_FAILED);
+    OP_TILING_CHECK(expertIdStorageShape == nullptr, OP_LOGE(K_INNER_DEBUG, "expertIdShape is null."),
+                    return GRAPH_FAILED);
+    OP_TILING_CHECK(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
+                    OP_LOGE(K_INNER_DEBUG, "x dims is invalid."), return GRAPH_FAILED);
+    OP_TILING_CHECK(expertIdStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS,
+                    OP_LOGE(K_INNER_DEBUG, "expertId dims is invalid."), return GRAPH_FAILED);
+    OP_LOGD(nodeName, "X dim0 = %ld", xStorageShape->GetStorageShape().GetDim(0));
+    OP_LOGD(nodeName, "X dim1 = %ld", xStorageShape->GetStorageShape().GetDim(1));
+    OP_LOGD(nodeName, "expertId dim0 = %ld", expertIdStorageShape->GetStorageShape().GetDim(0));
+    OP_LOGD(nodeName, "expertId dim1 = %ld", expertIdStorageShape->GetStorageShape().GetDim(1));
+
+    uint32_t h = static_cast<uint32_t>(xStorageShape->GetStorageShape().GetDim(1));
+    uint32_t bs = static_cast<uint32_t>(expertIdStorageShape->GetStorageShape().GetDim(0));
+    uint32_t k = static_cast<uint32_t>(expertIdStorageShape->GetStorageShape().GetDim(1));
+    bool isScales = (scalesStorageShape != nullptr);
+    auto attrs = context.GetAttrs();
+    OP_TILING_CHECK(attrs == nullptr, OP_LOGE(K_INNER_DEBUG, "attrs is null."), return ge::GRAPH_FAILED);
+    auto quantModePtr = attrs->GetAttrPointer(ATTR_QUANT_MODE_INDEX);
+    OP_TILING_CHECK(quantModePtr == nullptr, OP_LOGE(K_INNER_DEBUG, "quantModePtr is null."), return GRAPH_FAILED);
+    OP_TILING_CHECK(h % BLOCK_SIZE_A2 != 0 || h <= 0 || h > MAX_HIDDEN_SIZE_A2,
+                    OP_LOGE(K_INNER_DEBUG, "hiddensize is invalid."), return GRAPH_FAILED);
+    OP_TILING_CHECK(bs <= 0, OP_LOGE(K_INNER_DEBUG, "batchsize is invalid."), return GRAPH_FAILED);
+    OP_TILING_CHECK(k <= 0 || k > MAX_K_VALUE_A2, OP_LOGE(K_INNER_DEBUG, "k is invalid."), return GRAPH_FAILED);
+    OP_TILING_CHECK(*quantModePtr == UNQUANT_MODE && isScales,
+                    OP_LOGE(K_INNER_DEBUG, "scales should be null when quantMode is unQuant."), return GRAPH_FAILED);
+
+    const gert::StorageShape *tokenServerIdxStorageShape = context.GetInputShape(TOKEN_SERVER_IDX_INDEX);
+    OP_TILING_CHECK(tokenServerIdxStorageShape == nullptr,
+                    OP_LOGE(K_INNER_DEBUG, "tokenServerIdxStorageShape is null."), return GRAPH_FAILED);
+    const gert::StorageShape *tokenServerCntStorageShape = context.GetInputShape(TOKEN_SERVER_CNT_INDEX);
+    OP_TILING_CHECK(tokenServerCntStorageShape == nullptr,
+                    OP_LOGE(K_INNER_DEBUG, "tokenServerCntStorageShape is null."), return GRAPH_FAILED);
+    const gert::StorageShape *epRankTokenCntStorageShape = context.GetInputShape(EP_RANK_TOKEN_CNT_INDEX);
+    OP_TILING_CHECK(epRankTokenCntStorageShape == nullptr,
+                    OP_LOGE(K_INNER_DEBUG, "epRankTokenCntStorageShape is null."), return GRAPH_FAILED);
+    const gert::StorageShape *srcOffsetRankTokenIdxStorageShape =
+        context.GetInputShape(SRC_OFFSET_RANK_TOKEN_IDX_INDEX);
+    OP_TILING_CHECK(srcOffsetRankTokenIdxStorageShape == nullptr,
+                    OP_LOGE(K_INNER_DEBUG, "srcOffsetRankTokenIdxStorageShape is null."), return GRAPH_FAILED);
+    const gert::StorageShape *dstOffsetRankTokenIdxStorageShape =
+        context.GetInputShape(DST_OFFSET_RANK_TOKEN_IDX_INDEX);
+    OP_TILING_CHECK(dstOffsetRankTokenIdxStorageShape == nullptr,
+                    OP_LOGE(K_INNER_DEBUG, "dstOffsetRankTokenIdxStorageShape is null."), return GRAPH_FAILED);
+
+    info.isQuant = isScales;
+    info.bs = bs;
+    info.k = k;
+    info.h = h;
+
+    OP_LOGD(K_INNER_DEBUG, "isQuant=%d", info.isQuant);
+    OP_LOGD(K_INNER_DEBUG, "batchSize=%d", info.bs);
+    OP_LOGD(K_INNER_DEBUG, "k=%d", info.k);
+    OP_LOGD(K_INNER_DEBUG, "hiddenSize=%d", info.h);
+
+    return ge::GRAPH_SUCCESS;
+}
+
+static ge::graphStatus MoeDistributeDispatchA2CheckAttrAndSetTiling(const gert::TilingContext &context,
+                                                                    CamMoeDistributeDispatchA2Info 
&info) +{ + auto attrs = context.GetAttrs(); + OP_TILING_CHECK(attrs == nullptr, OP_LOGE(K_INNER_DEBUG, "attrs is null."), return ge::GRAPH_FAILED); + + auto groupEpPtr = attrs->GetAttrPointer(static_cast(ATTR_GROUP_EP_INDEX)); + auto epWorldSizePtr = attrs->GetAttrPointer(ATTR_EP_WORLD_SIZE_INDEX); + auto epRankIdPtr = attrs->GetAttrPointer(ATTR_EP_RANK_ID_INDEX); + auto moeExpertNumPtr = attrs->GetAttrPointer(ATTR_MOE_EXPERT_NUM_INDEX); + auto tpWorldSizePtr = attrs->GetAttrPointer(ATTR_TP_WORLD_SIZE_INDEX); + auto tpRankIdPtr = attrs->GetAttrPointer(ATTR_TP_RANK_ID_INDEX); + auto expertSharedTypePtr = attrs->GetAttrPointer(ATTR_EXPERT_SHARD_TYPE_INDEX); + auto sharedExpertRankNumPtr = attrs->GetAttrPointer(ATTR_SHARED_EXPERT_RANK_NUM_INDEX); + auto quantModePtr = attrs->GetAttrPointer(ATTR_QUANT_MODE_INDEX); + auto globalBsPtr = attrs->GetAttrPointer(ATTR_GLOBAL_BS_INDEX); + auto expertTokenNumsTypePtr = attrs->GetAttrPointer(ATTR_EXPERT_TOKEN_NUMS_TYPE_INDEX); + + const gert::StorageShape *expertIdStorageShape = context.GetInputShape(EXPERT_IDS_INDEX); + OP_TILING_CHECK(expertIdStorageShape == nullptr, OP_LOGE(K_INNER_DEBUG, "expertIdShape is null."), + return GRAPH_FAILED); + int32_t bs = expertIdStorageShape->GetStorageShape().GetDim(0); + + OP_TILING_CHECK(groupEpPtr == nullptr || strlen(groupEpPtr) == 0, OP_LOGE(K_INNER_DEBUG, "groupEp is invalid."), + return GRAPH_FAILED); + OP_TILING_CHECK(epWorldSizePtr == nullptr || *epWorldSizePtr <= 0 || *epWorldSizePtr > MAX_EP_WORLD_SIZE_A2 || + *epWorldSizePtr % RANK_NUM_PER_NODE_A2 != 0, + OP_LOGE(K_INNER_DEBUG, "epWorldSize is invalid."), return GRAPH_FAILED); + OP_TILING_CHECK(epRankIdPtr == nullptr || *epRankIdPtr < 0 || *epRankIdPtr >= *epWorldSizePtr, + OP_LOGE(K_INNER_DEBUG, "epRankId is invalid."), return GRAPH_FAILED); + OP_TILING_CHECK(moeExpertNumPtr == nullptr, OP_LOGE(K_INNER_DEBUG, "moeExpertNumPtr is null."), + return GRAPH_FAILED); + OP_TILING_CHECK( + *moeExpertNumPtr % *epWorldSizePtr != 0 || *moeExpertNumPtr <= 0 || *moeExpertNumPtr > MAX_MOE_EXPERT_NUMS_A2, + OP_LOGE(K_INNER_DEBUG, "moeExpertNum is invalid, only support (0, %d], but got moeExpertNum=%d.", + MAX_MOE_EXPERT_NUMS_A2, *moeExpertNumPtr), + return GRAPH_FAILED); + OP_TILING_CHECK(tpWorldSizePtr == nullptr, OP_LOGE(K_INNER_DEBUG, "tpWorldSize is null."), return GRAPH_FAILED); + OP_TILING_CHECK(tpRankIdPtr == nullptr, OP_LOGE(K_INNER_DEBUG, "tpRankId is null."), return GRAPH_FAILED); + OP_TILING_CHECK(expertSharedTypePtr == nullptr, OP_LOGE(K_INNER_DEBUG, "expertSharedType is null."), + return GRAPH_FAILED); + OP_TILING_CHECK(sharedExpertRankNumPtr == nullptr, OP_LOGE(K_INNER_DEBUG, "sharedExpertRankNum is null."), + return GRAPH_FAILED); + OP_TILING_CHECK(quantModePtr == nullptr || (*quantModePtr != UNQUANT_MODE && *quantModePtr != DYNAMIC_QUANT_MODE), + OP_LOGE(K_INNER_DEBUG, "quantMode is invalid."), return GRAPH_FAILED); + OP_TILING_CHECK(globalBsPtr == nullptr, OP_LOGE(K_INNER_DEBUG, "globalBs is null."), return GRAPH_FAILED); + OP_TILING_CHECK(expertTokenNumsTypePtr == nullptr || *expertTokenNumsTypePtr < 0 || *expertTokenNumsTypePtr > 1, + OP_LOGE(K_INNER_DEBUG, "expertTokenNumsType is invalid. Must be 0 or 1. 
"), return GRAPH_FAILED); + + info.epWorldSize = *epWorldSizePtr; + info.tpWorldSize = static_cast(0); + info.epRankId = *epRankIdPtr; + info.tpRankId = static_cast(0); + info.expertSharedType = static_cast(0); + info.sharedExpertRankNum = static_cast(0); + info.moeExpertNum = *moeExpertNumPtr; + info.quantMode = *quantModePtr; + info.globalBs = static_cast(*epWorldSizePtr * bs); + info.expertTokenNumsType = *expertTokenNumsTypePtr; + + OP_LOGD(K_INNER_DEBUG, "quantMode=%d", info.quantMode); + OP_LOGD(K_INNER_DEBUG, "globalBs=%d", info.globalBs); + OP_LOGD(K_INNER_DEBUG, "expertTokenNumsType=%d", info.expertTokenNumsType); + OP_LOGD(K_INNER_DEBUG, "expertSharedType=%d", info.expertSharedType); + OP_LOGD(K_INNER_DEBUG, "sharedExpertRankNum=%d", info.sharedExpertRankNum); + OP_LOGD(K_INNER_DEBUG, "moeExpertNum=%d", info.moeExpertNum); + OP_LOGD(K_INNER_DEBUG, "epWorldSize=%d", info.epWorldSize); + OP_LOGD(K_INNER_DEBUG, "tpWorldSize=%d", info.tpWorldSize); + OP_LOGD(K_INNER_DEBUG, "epRankId=%d", info.epRankId); + OP_LOGD(K_INNER_DEBUG, "tpRankId=%d", info.tpRankId); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus MoeDistributeDispatchA2GetPlatformInfoAndSetTiling(const gert::TilingContext &context, + CamMoeDistributeDispatchA2Info &info) +{ + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context.GetPlatformInfo()); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint64_t ubSize = 0U; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + info.aivNum = aivNum; + info.totalUbSize = ubSize; + + OP_LOGD(K_INNER_DEBUG, "aivNum=%d", info.aivNum); + OP_LOGD(K_INNER_DEBUG, "ubSize=%lu", info.totalUbSize); + + return ge::GRAPH_SUCCESS; +} + +static uint64_t MoeDistributeDispatchA2CalcTilingKey(const gert::TilingContext &context) +{ + uint64_t tilingKey = TILING_KEY_BASE_A2 + INIT_TILINGKEY; + std::string hcclIntraPcieEnableStr; + std::string hcclIntraRoceEnableStr; + const char *hcclIntraPcieEnable = getenv("HCCL_INTRA_PCIE_ENABLE"); + if (hcclIntraPcieEnable != nullptr) { + hcclIntraPcieEnableStr = hcclIntraPcieEnable; + } + const char *hcclIntraRoceEnable = getenv("HCCL_INTRA_ROCE_ENABLE"); + if (hcclIntraRoceEnable != nullptr) { + hcclIntraRoceEnableStr = hcclIntraRoceEnable; + } + + if (hcclIntraPcieEnableStr.empty() || hcclIntraRoceEnableStr.empty()) { + OP_LOGD(K_INNER_DEBUG, "ENV HCCL_INTRA_PCIE_ENABLE or HCCL_INTRA_ROCE_ENABLE don't set"); + } else if (hcclIntraPcieEnableStr == "1" && hcclIntraRoceEnableStr == "0") { + tilingKey += TILING_KEY_LAYERED_COMM_A2; + OP_LOGD(K_INNER_DEBUG, "ENV HCCL_INTRA_PCIE_ENABLE = 1 and HCCL_INTRA_ROCE_ENABLE = 0, use layered solution."); + } else { + OP_LOGD(K_INNER_DEBUG, "ENV HCCL_INTRA_PCIE_ENABLE != 1 or HCCL_INTRA_ROCE_ENABLE != 0, use default solution."); + } + + auto attrs = context.GetAttrs(); + const char *nodeName = context.GetNodeName(); + OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is null."), return 0); + auto quantModePtr = attrs->GetAttrPointer(ATTR_QUANT_MODE_INDEX); + tilingKey += static_cast(*quantModePtr); + + const gert::StorageShape *scalesStorageShape = context.GetOptionalInputShape(SCALES_INDEX); + bool isScales = (scalesStorageShape != nullptr); + tilingKey += static_cast((isScales ? 
SCALES_TILING_KEY : 0)); + + OP_LOGD(K_INNER_DEBUG, "tilingKey=%lu", tilingKey); + + return tilingKey; +} + +static ge::graphStatus MoeDistributeDispatchA2TilingFuncImpl(gert::TilingContext &context) +{ + const char *nodeName = context.GetNodeName(); + OP_LOGD(nodeName, "start MoeDistributeDispatchA2TilingFuncImpl func."); + OP_LOGI(nodeName, "Enter MoeDistributeDispatchA2 tiling func."); + + // 1. tilingData + CamMoeDistributeDispatchA2TilingData *tilingData = context.GetTilingData(); + OP_TILING_CHECK(tilingData == nullptr, VECTOR_INNER_ERR_REPORT_TILIING(nodeName, "tilingData is nullptr."), + return ge::GRAPH_FAILED); + OP_LOGI(nodeName, "MoeDistributeDispatchA2 get tilingData."); + CamMoeDistributeDispatchA2Info &info = tilingData->moeDistributeDispatchInfo; + OP_LOGI(nodeName, "MoeDistributeDispatchA2 get tilingData info."); + + OP_TILING_CHECK( + MoeDistributeDispatchA2CheckShapeAndSetTiling(context, info) != ge::GRAPH_SUCCESS, + VECTOR_INNER_ERR_REPORT_TILIING(context.GetNodeName(), "MoeDistributeDispatchA2 CheckShapeAndSetTiling Failed"), + return ge::GRAPH_FAILED); + OP_TILING_CHECK( + MoeDistributeDispatchA2CheckAttrAndSetTiling(context, info) != ge::GRAPH_SUCCESS, + VECTOR_INNER_ERR_REPORT_TILIING(context.GetNodeName(), "MoeDistributeDispatchA2 CheckAttrAndSetTiling Failed"), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(MoeDistributeDispatchA2GetPlatformInfoAndSetTiling(context, info) != ge::GRAPH_SUCCESS, + VECTOR_INNER_ERR_REPORT_TILIING(context.GetNodeName(), + "MoeDistributeDispatchA2 GetPlatformInfoAndSetTiling Failed"), + return ge::GRAPH_FAILED); + + uint32_t blockDim = 1U; + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context.GetPlatformInfo()); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum); + context.SetBlockDim(blockDim); + + uint64_t tilingKey = MoeDistributeDispatchA2CalcTilingKey(context); + context.SetTilingKey(tilingKey); + if ((tilingKey & TILING_KEY_LAYERED_COMM_A2) != 0) { + OP_TILING_CHECK(info.k != 8, OP_LOGE(nodeName, "In layered mode, k must be 8."), return ge::GRAPH_FAILED); + } + // 2. workspace + size_t *workSpaces = context.GetWorkspaceSizes(1); + OP_TILING_CHECK(workSpaces == nullptr, VECTOR_INNER_ERR_REPORT_TILIING(nodeName, "workSpaces is nullptr."), + return ge::GRAPH_FAILED); + // the second USER_WORKSPACE_A2 is reserved for dump/profiling + workSpaces[0] = SYSTEM_NEED_WORKSPACE + USER_WORKSPACE_A2 + USER_WORKSPACE_A2; + + // 3. 
communication + auto attrs = context.GetAttrs(); + OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); + auto group = attrs->GetAttrPointer(static_cast(ATTR_GROUP_EP_INDEX)); + uint32_t opType = 18; // batch write=18, + std::string algConfig = "MultiPut=level0:fullmesh"; + AscendC::Mc2CcTilingConfig mc2CcTilingConfig(group, opType, algConfig); + mc2CcTilingConfig.GetTiling(tilingData->mc2InitTiling); + mc2CcTilingConfig.GetTiling(tilingData->mc2CcTiling); + + OP_LOGD(nodeName, "Leave MoeDistributeDispatchA2 tiling func."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus DispatchNormalA2TilingFunc(gert::TilingContext *context) +{ + ge::graphStatus ret = MoeDistributeDispatchA2TilingFuncImpl(*context); + return ret; +} + +struct DispatchNormalA2CompileInfo {}; +ge::graphStatus TilingParseForDispatchNormalA2(gert::TilingParseContext *context) +{ + (void)context; + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(DispatchNormalA2) + .Tiling(DispatchNormalA2TilingFunc) + .TilingParse(TilingParseForDispatchNormalA2); +} // namespace optiling diff --git a/csrc/deepep/ops2/op_host/error_log.h b/csrc/deepep/ops2/op_host/error_log.h new file mode 100644 index 00000000..84258321 --- /dev/null +++ b/csrc/deepep/ops2/op_host/error_log.h @@ -0,0 +1,48 @@ +#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ +#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ + +#include +#include "toolchain/slog.h" + +#define OP_LOGI(opname, ...) +#define OP_LOGW(opname, ...) \ + do { \ + printf("[WARN][%s] ", (opname)); \ + printf(__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE_WITHOUT_REPORT(opname, ...) \ + do { \ + printf("[ERRORx][%s] ", (opname)); \ + printf(__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +#define OP_LOGE(opname, ...) \ + do { \ + printf("[ERROR][%s] ", (opname)); \ + printf(__VA_ARGS__); \ + printf("\n"); \ + } while (0) + +// #define OP_LOGD(opname, ...) printf("[DEBUG]" __VA_ARGS__); printf("\n"); +#define OP_LOGD(opname, ...) + +namespace optiling { + +#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) 
\ + do { \ + OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ + } while (0) + +#define OP_TILING_CHECK(cond, log_func, expr) \ + do { \ + if (cond) { \ + log_func; \ + expr; \ + } \ + } while (0) +} // namespace optiling + +#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ diff --git a/csrc/deepep/ops2/op_host/moe_distribute_combine_a2.cpp b/csrc/deepep/ops2/op_host/moe_distribute_combine_a2.cpp new file mode 100644 index 00000000..99e1e18c --- /dev/null +++ b/csrc/deepep/ops2/op_host/moe_distribute_combine_a2.cpp @@ -0,0 +1,133 @@ +#include "register/op_def_registry.h" + +namespace ops { +class MoeDistributeCombineA2 : public OpDef +{ +public: + explicit MoeDistributeCombineA2(const char *name) : OpDef(name) + { + this->Input("expand_x") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("expert_ids") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("expand_idx") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("ep_send_counts") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("expert_scales") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("tp_send_counts") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("x_active_mask") + .ParamType(OPTIONAL) + .DataType({ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("activation_scale") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("weight_scale") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("group_list") + .ParamType(OPTIONAL) + .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) + 
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("expand_scales") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("offsetInner") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("offsetOuter") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + this->Input("countOuter") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .AutoContiguous(); + + this->Output("x") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + + this->Attr("group_ep").AttrType(REQUIRED).String(); + this->Attr("ep_world_size").AttrType(REQUIRED).Int(); + this->Attr("ep_rank_id").AttrType(REQUIRED).Int(); + this->Attr("moe_expert_num").AttrType(REQUIRED).Int(); + this->Attr("group_tp").AttrType(OPTIONAL).String(""); + this->Attr("tp_world_size").AttrType(OPTIONAL).Int(0); + this->Attr("tp_rank_id").AttrType(OPTIONAL).Int(0); + this->Attr("expert_shard_type").AttrType(OPTIONAL).Int(0); + this->Attr("shared_expert_num").AttrType(OPTIONAL).Int(1); + this->Attr("shared_expert_rank_num").AttrType(OPTIONAL).Int(0); + this->Attr("global_bs").AttrType(OPTIONAL).Int(0); + this->Attr("out_dtype").AttrType(OPTIONAL).Int(0); + this->Attr("comm_quant_mode").AttrType(OPTIONAL).Int(0); + this->Attr("group_list_type").AttrType(OPTIONAL).Int(0); + + OpAICoreConfig aicore_config_A2; + aicore_config_A2.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") + .ExtendCfgInfo("jitCompile.flag", "static_false") + .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); + + this->AICore().AddConfig("ascend910b", aicore_config_A2); + this->MC2().HcclGroup("group_ep"); // A2 does not support registering more than one communication group; doing so causes a hang + } +}; + +OP_ADD(MoeDistributeCombineA2); + +} // namespace ops diff --git a/csrc/deepep/ops2/op_host/moe_distribute_combine_a2_tiling.cpp b/csrc/deepep/ops2/op_host/moe_distribute_combine_a2_tiling.cpp new file mode 100644 index 00000000..b70825d8 --- /dev/null +++ b/csrc/deepep/ops2/op_host/moe_distribute_combine_a2_tiling.cpp @@ -0,0 +1,316 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "register/tilingdata_base.h" +#include "tiling/tiling_api.h" +#include 
"error_log.h" +#include "graph/utils/type_utils.h" +// #include "hcom_topo_info.h" +// #include "error/ops_error.h" +#include "register/op_def_registry.h" +#include "../op_kernel/moe_distribute_combine_a2_tiling.h" +#include "tiling_args.h" + +#define OPS_CHECK OP_TILING_CHECK +#define OPS_LOG_E OP_LOGE +#define OPS_LOG_I OP_LOGI +#define OPS_LOG_D OP_LOGD +#define OPS_REPORT_VECTOR_INNER_ERR VECTOR_INNER_ERR_REPORT_TILIING + +using namespace AscendC; +using namespace ge; + +namespace { +constexpr uint32_t EXPAND_X_INDEX = 0; +constexpr uint32_t EXPERT_IDS_INDEX = 1; +constexpr uint32_t EXPAND_IDX_INDEX = 2; +constexpr uint32_t EP_SEND_COUNTS_INDEX = 3; +constexpr uint32_t EXPERT_SCALES_INDEX = 4; +constexpr uint32_t TP_SEND_COUNTS_INDEX = 5; +constexpr uint32_t X_ACTIVE_MASK_INDEX = 6; +constexpr uint32_t OUTPUT_X_INDEX = 0; + +constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; +constexpr uint32_t ATTR_EP_WORLD_SIZE_INDEX = 1; +constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2; +constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3; +constexpr uint32_t ATTR_GROUP_TP_INDEX = 4; +constexpr uint32_t ATTR_TP_WORLD_SIZE_INDEX = 5; +constexpr uint32_t ATTR_TP_RANK_ID_INDEX = 6; +constexpr uint32_t ATTR_EXPERT_SHARD_TYPE_INDEX = 7; +constexpr uint32_t ATTR_SHARED_EXPERT_NUM_INDEX = 8; +constexpr uint32_t ATTR_SHARED_EXPERT_RANK_NUM_INDEX = 9; +constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 10; + +constexpr uint32_t TWO_DIMS = 2U; +constexpr uint32_t ONE_DIM = 1U; +constexpr uint32_t EXPAND_IDX_DIMS = 1U; +constexpr uint64_t INIT_TILINGKEY_TP_2 = 1100UL; +constexpr uint64_t INIT_TILINGKEY_TP_1 = 1000UL; +constexpr uint64_t TILING_KEY_BASE_A2 = 2000UL; +constexpr uint64_t TILING_KEY_LAYERED_COMM_A2 = 3000UL; +constexpr uint32_t ARR_LENGTH = 128U; +constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8U; // numeric representation of AlltoAll +constexpr uint32_t OP_TYPE_REDUCE_SCATTER = 7U; // numeric representation of AlltoAll + +constexpr int32_t MAX_EP_WORLD_SIZE_A2 = 256; +constexpr int32_t MAX_MOE_EXPERT_NUMS_A2 = 512; +constexpr int32_t MAX_HIDDEN_SIZE_A2 = 7168; +constexpr uint32_t MAX_BATCH_SIZE_LAYERED_A2 = 128; +constexpr uint32_t MAX_BATCH_SIZE_A2 = 256; +constexpr uint32_t RANK_NUM_PER_NODE_A2 = 8; +constexpr uint32_t BLOCK_SIZE_A2 = 32; +constexpr uint32_t MAX_K_VALUE_A2 = 16; +constexpr uint32_t MIN_K_VALUE_A2 = 2; +const char *K_INNER_DEBUG = "MoeDistributeCombine Tiling Debug"; +const size_t MAX_GROUP_NAME_LENGTH = 128UL; +const int64_t MAX_EP_WORLD_SIZE = 288; +const int64_t MAX_TP_WORLD_SIZE = 2; +const int64_t BS_UPPER_BOUND = 512; + +constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; +constexpr int32_t HCCL_BUFFER_SIZE_DEFAULT = 200 * 1024 * 1024; // Bytes +constexpr uint32_t VERSION_2 = 2; +constexpr uint32_t HCOMMCNT_2 = 2; +constexpr int64_t MOE_EXPERT_MAX_NUM = 512; +constexpr int64_t K_MAX = 8; +constexpr uint64_t MB_SIZE = 1024UL * 1024UL; +} // namespace + +namespace optiling { +static ge::graphStatus MoeDistributeCombineA2CheckAttrAndSetTiling(gert::TilingContext *context, + MoeDistributeCombineA2Info &info) +{ + auto attrs = context->GetAttrs(); + OPS_CHECK(attrs == nullptr, OPS_LOG_E(K_INNER_DEBUG, "attrs is null."), return ge::GRAPH_FAILED); + + auto epWorldSizePtr = attrs->GetAttrPointer(ATTR_EP_WORLD_SIZE_INDEX); + auto epRankIdPtr = attrs->GetAttrPointer(ATTR_EP_RANK_ID_INDEX); + auto moeExpertNumPtr = attrs->GetAttrPointer(ATTR_MOE_EXPERT_NUM_INDEX); + auto tpWorldSizePtr = attrs->GetAttrPointer(ATTR_TP_WORLD_SIZE_INDEX); + auto tpRankIdPtr = attrs->GetAttrPointer(ATTR_TP_RANK_ID_INDEX); + 
auto expertSharedTypePtr = attrs->GetAttrPointer(ATTR_EXPERT_SHARD_TYPE_INDEX); + auto sharedExpertRankNumPtr = attrs->GetAttrPointer(ATTR_SHARED_EXPERT_RANK_NUM_INDEX); + auto globalBsPtr = attrs->GetAttrPointer(ATTR_GLOBAL_BS_INDEX); + + OPS_CHECK(epWorldSizePtr == nullptr || *epWorldSizePtr <= 0 || *epWorldSizePtr > MAX_EP_WORLD_SIZE_A2 || + *epWorldSizePtr % RANK_NUM_PER_NODE_A2 != 0, + OPS_LOG_E(K_INNER_DEBUG, "epWorldSize is invalid."), return GRAPH_FAILED); + OPS_CHECK(epRankIdPtr == nullptr || *epRankIdPtr < 0 || *epRankIdPtr >= *epWorldSizePtr, + OPS_LOG_E(K_INNER_DEBUG, "epRankId is invalid."), return GRAPH_FAILED); + OPS_CHECK(moeExpertNumPtr == nullptr || *moeExpertNumPtr <= 0 || *moeExpertNumPtr > MAX_MOE_EXPERT_NUMS_A2 || + *moeExpertNumPtr % *epWorldSizePtr != 0, + OPS_LOG_E(K_INNER_DEBUG, "moeExpertNum is invalid."), return GRAPH_FAILED); + OPS_CHECK(tpWorldSizePtr == nullptr, OPS_LOG_E(K_INNER_DEBUG, "tpWorldSize is null."), return GRAPH_FAILED); + OPS_CHECK(tpRankIdPtr == nullptr, OPS_LOG_E(K_INNER_DEBUG, "tpRankId is null."), return GRAPH_FAILED); + OPS_CHECK(expertSharedTypePtr == nullptr, OPS_LOG_E(K_INNER_DEBUG, "expertSharedType is null."), + return GRAPH_FAILED); + OPS_CHECK(sharedExpertRankNumPtr == nullptr, OPS_LOG_E(K_INNER_DEBUG, "sharedExpertRankNum is null."), + return GRAPH_FAILED); + OPS_CHECK(globalBsPtr == nullptr, OPS_LOG_E(K_INNER_DEBUG, "globalBs is null."), return GRAPH_FAILED); + + const gert::StorageShape *expertIdStorageShape = context->GetInputShape(EXPERT_IDS_INDEX); + OPS_CHECK(expertIdStorageShape == nullptr, OPS_LOG_E(K_INNER_DEBUG, "expertIdShape is null."), return GRAPH_FAILED); + int32_t globalBs = *epWorldSizePtr * expertIdStorageShape->GetStorageShape().GetDim(0); + + info.epWorldSize = *epWorldSizePtr; + info.tpWorldSize = static_cast(0); + info.epRankId = *epRankIdPtr; + info.tpRankId = static_cast(0); + info.expertSharedType = static_cast(0); + info.sharedExpertRankNum = static_cast(0); + info.moeExpertNum = *moeExpertNumPtr; + if (*globalBsPtr == 0) { + info.globalBs = static_cast(globalBs); + } else { + info.globalBs = *globalBsPtr; + } + + OPS_LOG_D(K_INNER_DEBUG, "epWorldSize=%u", info.epWorldSize); + OPS_LOG_D(K_INNER_DEBUG, "tpWorldSize=%u", info.tpWorldSize); + OPS_LOG_D(K_INNER_DEBUG, "epRankId=%u", info.epRankId); + OPS_LOG_D(K_INNER_DEBUG, "tpRankId=%u", info.tpRankId); + OPS_LOG_D(K_INNER_DEBUG, "expertSharedType=%u", info.expertSharedType); + OPS_LOG_D(K_INNER_DEBUG, "sharedExpertRankNum=%u", info.sharedExpertRankNum); + OPS_LOG_D(K_INNER_DEBUG, "moeExpertNum=%u", info.moeExpertNum); + OPS_LOG_D(K_INNER_DEBUG, "globalBs=%u", info.globalBs); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus MoeDistributeCombineA2CheckShapeAndSetTiling(gert::TilingContext *context, + MoeDistributeCombineA2Info &info, + const bool isLayered) +{ + const gert::StorageShape *expandXStorageShape = context->GetInputShape(EXPAND_X_INDEX); + const gert::StorageShape *expertIdStorageShape = context->GetInputShape(EXPERT_IDS_INDEX); + OPS_CHECK(expandXStorageShape == nullptr, OPS_LOG_E(K_INNER_DEBUG, "expandXShape is null."), return GRAPH_FAILED); + OPS_CHECK(expertIdStorageShape == nullptr, OPS_LOG_E(K_INNER_DEBUG, "expertIdShape is null."), return GRAPH_FAILED); + + OPS_CHECK(expandXStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OPS_LOG_E(K_INNER_DEBUG, "expandXshape is invalid"), return GRAPH_FAILED); + uint32_t h = expandXStorageShape->GetStorageShape().GetDim(1); + OPS_CHECK(h <= 0 || h > MAX_HIDDEN_SIZE_A2 || h % BLOCK_SIZE_A2 != 0, + 
OPS_LOG_E(K_INNER_DEBUG, "hiddensize is invalid."), return GRAPH_FAILED); + OPS_CHECK(expertIdStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OPS_LOG_E(K_INNER_DEBUG, "expertIdshape is invalid"), return GRAPH_FAILED); + uint32_t bs = expertIdStorageShape->GetStorageShape().GetDim(0); + OPS_CHECK(bs <= 0, OPS_LOG_E(K_INNER_DEBUG, "batchsize is invalid."), return GRAPH_FAILED); + uint32_t k = expertIdStorageShape->GetStorageShape().GetDim(1); + OPS_CHECK(k < MIN_K_VALUE_A2 || k > MAX_K_VALUE_A2, OPS_LOG_E(K_INNER_DEBUG, "k is invalid."), return GRAPH_FAILED); + const uint32_t maxBatchSize = isLayered ? MAX_BATCH_SIZE_LAYERED_A2 : MAX_BATCH_SIZE_A2; + OPS_CHECK(bs > maxBatchSize, OPS_LOG_E(K_INNER_DEBUG, "Batchsize must be smaller than %u.", maxBatchSize), + return ge::GRAPH_FAILED); + info.bs = bs; + info.k = k; + info.h = h; + + OPS_LOG_D(K_INNER_DEBUG, "batchSize=%u", bs); + OPS_LOG_D(K_INNER_DEBUG, "k=%u", k); + OPS_LOG_D(K_INNER_DEBUG, "hidenSize=%u", h); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus MoeDistributeCombineA2GetPlatformInfoAndSetTiling(gert::TilingContext *context, + MoeDistributeCombineA2Info &info) +{ + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint64_t ubSize = 0U; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + + info.aivNum = aivNum; + info.totalUbSize = ubSize; + + OPS_LOG_D(K_INNER_DEBUG, "aivNum=%u", info.aivNum); + OPS_LOG_D(K_INNER_DEBUG, "ubSize=%lu", info.totalUbSize); + + return ge::GRAPH_SUCCESS; +} + +static bool MoeDistributeCombineA2IsLayered() +{ + const char *hcclIntraPcieEnable = getenv("HCCL_INTRA_PCIE_ENABLE"); + const char *hcclIntraRoceEnable = getenv("HCCL_INTRA_ROCE_ENABLE"); + + if (hcclIntraPcieEnable == nullptr || hcclIntraRoceEnable == nullptr) { + OPS_LOG_D(K_INNER_DEBUG, "ENV HCCL_INTRA_PCIE_ENABLE or HCCL_INTRA_ROCE_ENABLE don't set"); + return false; + } + if (strcmp(hcclIntraPcieEnable, "1") == 0 && strcmp(hcclIntraRoceEnable, "0") == 0) { + OPS_LOG_D(K_INNER_DEBUG, + "ENV HCCL_INTRA_PCIE_ENABLE = 1 and HCCL_INTRA_ROCE_ENABLE = 0, use layered solution."); + return true; + } + OPS_LOG_D(K_INNER_DEBUG, "ENV HCCL_INTRA_PCIE_ENABLE != 1 or HCCL_INTRA_ROCE_ENABLE != 0, use default solution."); + return false; +} + +static uint64_t MoeDistributeCombineA2CalcTilingKey(gert::TilingContext *context, const bool isLayered) +{ + const char *nodeName = context->GetNodeName(); + OPS_LOG_I(nodeName, "Enter MoeDistributeCombineA2 calc tiling func."); + + uint64_t tilingKey = TILING_KEY_BASE_A2; + + if (isLayered) { + tilingKey = TILING_KEY_LAYERED_COMM_A2; + } + + OPS_LOG_D(K_INNER_DEBUG, "tilingKey=%lu", tilingKey); + + return tilingKey; +} + +static ge::graphStatus MoeDistributeCombineA2TilingFuncImpl(gert::TilingContext *context) +{ + const char *nodeName = context->GetNodeName(); + OPS_LOG_I(nodeName, "Enter MoeDistributeCombineA2 tiling func."); + + // tilingData + MoeDistributeCombineA2TilingData *tilingData = context->GetTilingData(); + OPS_CHECK(tilingData == nullptr, OPS_REPORT_VECTOR_INNER_ERR(nodeName, "tilingData is nullptr."), + return ge::GRAPH_FAILED); + OPS_LOG_I(nodeName, "MoeDistributeCombineA2 get tilingData."); + MoeDistributeCombineA2Info &info = tilingData->moeDistributeCombineInfo; + + bool isLayered = MoeDistributeCombineA2IsLayered(); + OPS_CHECK( + MoeDistributeCombineA2CheckShapeAndSetTiling(context, info, isLayered) != ge::GRAPH_SUCCESS, + 
OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "MoeDistributeCombineA2 CheckShapeAndSetTiling Failed"), + return ge::GRAPH_FAILED); + OPS_CHECK( + MoeDistributeCombineA2CheckAttrAndSetTiling(context, info) != ge::GRAPH_SUCCESS, + OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), "MoeDistributeCombineA2 CheckAttrAndSetTiling Failed"), + return ge::GRAPH_FAILED); + OPS_CHECK(MoeDistributeCombineA2GetPlatformInfoAndSetTiling(context, info) != ge::GRAPH_SUCCESS, + OPS_REPORT_VECTOR_INNER_ERR(context->GetNodeName(), + "MoeDistributeCombineA2 GetPlatformInfoAndSetTiling Failed"), + return ge::GRAPH_FAILED); + + uint32_t blockDim = 1U; + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + blockDim = ascendcPlatform.CalcTschBlockDim(aivNum, 0, aivNum); + context->SetBlockDim(blockDim); + + uint64_t tilingKey = MoeDistributeCombineA2CalcTilingKey(context, isLayered); + context->SetTilingKey(tilingKey); + // 2. workspace + size_t *workSpaces = context->GetWorkspaceSizes(1); + OPS_CHECK(workSpaces == nullptr, OPS_REPORT_VECTOR_INNER_ERR(nodeName, "workSpaces is nullptr."), + return ge::GRAPH_FAILED); + uint32_t userWorkspaceSize = static_cast(info.moeExpertNum) * sizeof(uint32_t) * 2; + workSpaces[0] = SYSTEM_NEED_WORKSPACE + userWorkspaceSize; + + // 3. communication + auto attrs = context->GetAttrs(); + auto group = attrs->GetAttrPointer(static_cast(ATTR_GROUP_EP_INDEX)); + uint32_t opType = 18; // batch write=18, + std::string algConfig = "MultiPut=level0:fullmesh"; + AscendC::Mc2CcTilingConfig mc2CcTilingConfig(group, opType, algConfig); + mc2CcTilingConfig.GetTiling(tilingData->mc2InitTiling); + mc2CcTilingConfig.GetTiling(tilingData->mc2CcTiling); + + OPS_LOG_I(nodeName, "Leave MoeDistributeCombineA2 tiling func."); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus MoeDistributeCombineTilingFunc(gert::TilingContext *context) +{ + // expandX with int32 data type is not supported + auto expandXDesc = context->GetInputDesc(EXPAND_X_INDEX); + const char *nodeName = context->GetNodeName(); + OPS_CHECK(expandXDesc == nullptr, OPS_LOG_E(nodeName, "expandxDesc is null."), return ge::GRAPH_FAILED); + // Reject expandX whose data type is DT_INT32 + OPS_CHECK((expandXDesc->GetDataType() == ge::DT_INT32), + OPS_LOG_E(nodeName, "expandX dataType is invalid, dataType should be bf16 or float16, but is %d", + static_cast(expandXDesc->GetDataType())), + return ge::GRAPH_FAILED); + + return MoeDistributeCombineA2TilingFuncImpl(context); +} + +struct MoeDistributeCombineCompileInfo {}; +ge::graphStatus TilingParseForMoeDistributeCombine(gert::TilingParseContext *context) +{ + (void)context; + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(MoeDistributeCombineA2) + .Tiling(MoeDistributeCombineTilingFunc) + .TilingParse(TilingParseForMoeDistributeCombine); +} // namespace optiling diff --git a/csrc/deepep/ops2/op_host/notify_dispatch_a2.cpp b/csrc/deepep/ops2/op_host/notify_dispatch_a2.cpp new file mode 100644 index 00000000..25cd7897 --- /dev/null +++ b/csrc/deepep/ops2/op_host/notify_dispatch_a2.cpp @@ -0,0 +1,113 @@ +#include "register/op_def_registry.h" + +namespace ops { +class NotifyDispatchA2 : public OpDef +{ +public: + explicit NotifyDispatchA2(const char *name) : OpDef(name) + { + this->Input("sendData") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + 
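// Note: every NotifyDispatchA2 input and output below is registered with the same three dtype/format combinations (float16, float, int32; ND format). +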
this->Input("tokenPerExpertData") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("tmpData") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("sendDataOffset") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("recvData") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("tokenServerIdx") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("tokenUniquePerServer") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("epRankTokenCnt") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("localEpTokenCnt") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("srcOffsetRankTokenIdx") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("dstOffsetRankTokenIdx") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("offsetInner") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("countOuter") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("expandIdx") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + + this->Attr("sendCount").Int(); + this->Attr("num_tokens").Int(); + this->Attr("topk_num").Int(); + this->Attr("num_experts").Int(); + this->Attr("comm_group").String(); + this->Attr("rank_size").Int(); + this->Attr("rank_id").Int(); + this->Attr("local_rank_size").Int(); + this->Attr("local_rank_id").Int(); + + OpAICoreConfig aicore_config_base; + aicore_config_base.DynamicCompileStaticFlag(true) + 
.DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(true) + .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") + .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); + + OpAICoreConfig aicore_config_A2 = aicore_config_base; + aicore_config_A2.ExtendCfgInfo("jitCompile.flag", "static_false"); + + OpAICoreConfig aicore_config = aicore_config_base; + aicore_config.ExtendCfgInfo("jitCompile.flag", "static_true"); + + this->AICore().AddConfig("ascend910_93", aicore_config); + this->AICore().AddConfig("ascend910b", aicore_config_A2); + this->MC2().HcclGroup("comm_group"); + } +}; + +OP_ADD(NotifyDispatchA2); +} // namespace ops diff --git a/csrc/deepep/ops2/op_host/notify_dispatch_tiling_a2.cc b/csrc/deepep/ops2/op_host/notify_dispatch_tiling_a2.cc new file mode 100644 index 00000000..23e63a4a --- /dev/null +++ b/csrc/deepep/ops2/op_host/notify_dispatch_tiling_a2.cc @@ -0,0 +1,437 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "error_log.h" +#include "graph/utils/type_utils.h" +#include "register/op_def_registry.h" +#include "../op_kernel/notify_dispatch_tiling_a2.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/hccl/hccl_tiling.h" +#include "experiment/platform/platform/platform_infos_def.h" + +using namespace ge; +namespace { +class Mc2TilingUtils +{ +public: +#define HCCL_BUFFSIZE "HCCL_BUFFSIZE" + static uint64_t GetMaxWindowSize() + { + uint16_t defaultWindowSize = 200; + if (getenv(HCCL_BUFFSIZE) == nullptr) { + OP_LOGD("", "Env HCCL_BUFFSIZE don't set"); + } else { + try { + std::string envStr(getenv(HCCL_BUFFSIZE)); + defaultWindowSize = std::stoi(envStr); + } catch (const std::invalid_argument &ia) { + OP_LOGE("", "Invalid argument when parsing HCCL_BUFFSIZE: %s", ia.what()); + } catch (const std::out_of_range &oor) { + OP_LOGE("", "Out of range when parsing HCCL_BUFFSIZE: %s", oor.what()); + } + } + const uint64_t maxWindowSize = static_cast(defaultWindowSize) * 1024UL * 1024UL; + OP_LOGI("", "Get maxWindowSize is %lu", maxWindowSize); + return maxWindowSize; + } +}; +constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8U; // numeric representation of AlltoAll + +constexpr uint32_t INPUT_SEND_DATA_INDEX = 0; +constexpr uint32_t INPUT_TOKEN_PER_EXPERT_INDEX = 1; +constexpr uint32_t INPUT_TMP_DATA_INDEX = 2; + +constexpr uint32_t OUTPUT_SEND_DATA_OFFSET_INDEX = 0; +constexpr uint32_t OUTPUT_RECV_DATA_INDEX = 1; +constexpr uint32_t OUTPUT_TOKEN_SERVER_IDX_INDEX = 2; +constexpr uint32_t OUTPUT_TOKEN_UNIQUE_PER_SERVER_INDEX = 3; +constexpr uint32_t OUTPUT_EP_RANK_TOKEN_CNT_INDEX = 4; +constexpr uint32_t OUTPUT_LOCAL_EP_TOKEN_CNT_INDEX = 5; +constexpr uint32_t OUTPUT_SRC_OFFSET_RANK_TOKEN_INDEX = 6; +constexpr uint32_t OUTPUT_DST_OFFSET_RANK_TOKEN_INDEX = 7; +constexpr uint32_t OUTPUT_OFFSET_INNER_INDEX = 8; +constexpr uint32_t OUTPUT_COUNT_OUTER_INDEX = 9; +constexpr uint32_t OUTPUT_EXPAND_IDX_INDEX = 10; + +constexpr uint32_t ATTR_SEND_COUNT_INDEX = 0; +constexpr uint32_t ATTR_NUM_TOKENS_INDEX = 1; +constexpr uint32_t ATTR_TOPK_NUM_INDEX = 2; +constexpr uint32_t ATTR_NUM_EXPERTS_INDEX = 3; +constexpr uint32_t ATTR_COMM_GROUP_INDEX = 4; +constexpr uint32_t ATTR_RANK_SIZE_INDEX = 5; +constexpr uint32_t ATTR_RANK_ID_INDEX = 6; +constexpr uint32_t ATTR_LOCAL_RANK_SIZE_INDEX = 7; +constexpr uint32_t ATTR_LOCAL_RANK_ID_INDEX = 8; + +const size_t MAX_GROUP_NAME_LENGTH = 128UL; +const int64_t 
MAX_COMM_WORLD_SIZE = 384; + +constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; +constexpr uint32_t KERNEL_USE_WORKSPACE = 1 * 1024 * 1024; +constexpr uint32_t KERNEL_A2_ARG_SIZE = 16 * 1024 * 1024; +constexpr int32_t HCCL_BUFFER_SIZE_DEFAULT = 200 * 1024 * 1024; // Bytes +constexpr uint64_t MB_SIZE = 1024UL * 1024UL; + +constexpr static int TILING_KEY_FLOAT16 = 20; +constexpr static int TILING_KEY_BFLOAT16 = 21; +constexpr static int TILING_KEY_FLOAT = 22; +constexpr static int TILING_KEY_INT = 23; +constexpr static int TILING_KEY_A2_TYPE = 100; + +constexpr static int ALL_TO_ALL_CORE_NUM = 32; +} // namespace + +namespace optiling { +static void PrintTilingDataInfo(const char *nodeName, NotifyDispatchA2TilingData &tilingData) +{ + OP_LOGD(nodeName, "rankSize is %u.", tilingData.notifyDispatchInfoA2.rankSize); + OP_LOGD(nodeName, "rankId is %u.", tilingData.notifyDispatchInfoA2.rankId); + OP_LOGD(nodeName, "localRankSize is %u.", tilingData.notifyDispatchInfoA2.localRankSize); + OP_LOGD(nodeName, "localRankId is %u.", tilingData.notifyDispatchInfoA2.localRankId); + OP_LOGD(nodeName, "sendCount is %u.", tilingData.notifyDispatchInfoA2.sendCount); + OP_LOGD(nodeName, "numTokens is %u.", tilingData.notifyDispatchInfoA2.numTokens); + OP_LOGD(nodeName, "topkNum is %u.", tilingData.notifyDispatchInfoA2.topkNum); + OP_LOGD(nodeName, "numExperts is %u.", tilingData.notifyDispatchInfoA2.numExperts); + OP_LOGD(nodeName, "aivNum is %u.", tilingData.notifyDispatchInfoA2.aivNum); + OP_LOGD(nodeName, "totalUbSize is %lu.", tilingData.notifyDispatchInfoA2.totalUbSize); +} + +static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName, + NotifyDispatchA2TilingData &tilingData, std::string &commGroup) +{ + auto attrs = context->GetAttrs(); + OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); + + auto sendCountPtr = attrs->GetAttrPointer(ATTR_SEND_COUNT_INDEX); + auto numTokenPtr = attrs->GetAttrPointer(ATTR_NUM_TOKENS_INDEX); + auto topkNumPtr = attrs->GetAttrPointer(ATTR_TOPK_NUM_INDEX); + auto numExpertsPtr = attrs->GetAttrPointer(ATTR_NUM_EXPERTS_INDEX); + auto commGroupPtr = attrs->GetAttrPointer(static_cast(ATTR_COMM_GROUP_INDEX)); + auto rankSizePtr = attrs->GetAttrPointer(ATTR_RANK_SIZE_INDEX); + auto rankIdPtr = attrs->GetAttrPointer(ATTR_RANK_ID_INDEX); + auto localRankSizePtr = attrs->GetAttrPointer(ATTR_LOCAL_RANK_SIZE_INDEX); + auto localRankIdPtr = attrs->GetAttrPointer(ATTR_LOCAL_RANK_ID_INDEX); + + OP_TILING_CHECK((commGroupPtr == nullptr) || (strnlen(commGroupPtr, MAX_GROUP_NAME_LENGTH) == 0) || + (strnlen(commGroupPtr, MAX_GROUP_NAME_LENGTH) == MAX_GROUP_NAME_LENGTH), + OP_LOGE(nodeName, "commGroupPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(sendCountPtr == nullptr, OP_LOGE(nodeName, "sendCountPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(numTokenPtr == nullptr, OP_LOGE(nodeName, "numTokenPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(topkNumPtr == nullptr, OP_LOGE(nodeName, "topkNumPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(numExpertsPtr == nullptr, OP_LOGE(nodeName, "numExpertsPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(rankSizePtr == nullptr, OP_LOGE(nodeName, "rankSizePtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(rankIdPtr == nullptr, OP_LOGE(nodeName, "rankIdPtr is null."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(localRankSizePtr == nullptr, OP_LOGE(nodeName, "localRankSizePtr is null."), + 
return ge::GRAPH_FAILED); + OP_TILING_CHECK(localRankIdPtr == nullptr, OP_LOGE(nodeName, "localRankIdPtr is null."), return ge::GRAPH_FAILED); + + OP_TILING_CHECK((*rankSizePtr <= 0) || (*rankSizePtr > MAX_COMM_WORLD_SIZE), + OP_LOGE(nodeName, "rankSize is invalid, only support (0, %ld], but got rankSize=%ld.", + MAX_COMM_WORLD_SIZE, *rankSizePtr), + return ge::GRAPH_FAILED); + OP_TILING_CHECK( + (*rankIdPtr < 0) || (*rankIdPtr >= *rankSizePtr), + OP_LOGE(nodeName, "rankId is invalid, only support [0, %ld), but got rankId=%ld.", *rankSizePtr, *rankIdPtr), + return ge::GRAPH_FAILED); + OP_TILING_CHECK((*sendCountPtr <= 0), + OP_LOGE(nodeName, "sendCount is invalid, only support > 0, but got sendCount=%ld.", *sendCountPtr), + return ge::GRAPH_FAILED); + OP_TILING_CHECK( + (*numTokenPtr <= 0), + OP_LOGE(nodeName, "numTokenPtr is invalid, only support > 0, but got numTokenPtr=%ld.", *numTokenPtr), + return ge::GRAPH_FAILED); + + commGroup = std::string(commGroupPtr); + tilingData.notifyDispatchInfoA2.rankSize = static_cast(*rankSizePtr); + tilingData.notifyDispatchInfoA2.rankId = static_cast(*rankIdPtr); + tilingData.notifyDispatchInfoA2.localRankSize = static_cast(*localRankSizePtr); + tilingData.notifyDispatchInfoA2.localRankId = static_cast(*localRankIdPtr); + tilingData.notifyDispatchInfoA2.sendCount = static_cast(*sendCountPtr); + tilingData.notifyDispatchInfoA2.numTokens = static_cast(*numTokenPtr); + tilingData.notifyDispatchInfoA2.topkNum = static_cast(*topkNumPtr); + tilingData.notifyDispatchInfoA2.numExperts = static_cast(*numExpertsPtr); + + return ge::GRAPH_SUCCESS; +} + +static void SetHcommCfg(const gert::TilingContext *context, NotifyDispatchA2TilingData *tiling, + const std::string commGroup) +{ + const char *nodeName = context->GetNodeName(); + OP_LOGD(nodeName, "NotifyDispatchA2 commGroup = %s", commGroup.c_str()); + uint32_t opType1 = OP_TYPE_ALL_TO_ALL; + std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise"; + + AscendC::Mc2CcTilingConfig mc2CcTilingConfig(commGroup, opType1, algConfigAllToAllStr); + mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling); + mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling1); +} + +static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName) +{ + size_t *workSpaces = context->GetWorkspaceSizes(1); + OP_TILING_CHECK(workSpaces == nullptr, OP_LOGE(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED); + workSpaces[0] = SYSTEM_NEED_WORKSPACE + KERNEL_USE_WORKSPACE + + KERNEL_A2_ARG_SIZE; // TODO: extra space is reserved here; revisit when dispatch/combine synchronization changes 
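+ // Workspace layout: SYSTEM_NEED_WORKSPACE (16 MB) + KERNEL_USE_WORKSPACE (1 MB) + KERNEL_A2_ARG_SIZE (16 MB); see the constants defined above.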
+ return ge::GRAPH_SUCCESS; +} + +static bool CheckTensorDataType(gert::TilingContext *context, const char *nodeName) +{ + OP_LOGD(nodeName, "========CheckTensorDataType============"); + auto sendData = context->GetInputDesc(INPUT_SEND_DATA_INDEX); + OP_TILING_CHECK(sendData == nullptr, OP_LOGE(nodeName, "sendData is null."), return false); + OP_TILING_CHECK( + (sendData->GetDataType() != ge::DT_BF16) && (sendData->GetDataType() != ge::DT_FLOAT16) && + (sendData->GetDataType() != ge::DT_FLOAT) && (sendData->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, + "sendData datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(sendData->GetDataType())), + return false); + uint64_t dataSize; + if ((sendData->GetDataType() == ge::DT_BF16) || (sendData->GetDataType() == ge::DT_FLOAT16)) { + dataSize = 2; + } else { + dataSize = 4; + } + auto tokenPerExpertData = context->GetInputDesc(INPUT_TOKEN_PER_EXPERT_INDEX); + OP_TILING_CHECK(tokenPerExpertData == nullptr, OP_LOGE(nodeName, "tokenPerExpertData is null."), return false); + OP_TILING_CHECK( + (tokenPerExpertData->GetDataType() != ge::DT_BF16) && (tokenPerExpertData->GetDataType() != ge::DT_FLOAT16) && + (tokenPerExpertData->GetDataType() != ge::DT_FLOAT) && (tokenPerExpertData->GetDataType() != ge::DT_INT32), + OP_LOGE( + nodeName, + "tokenPerExpertData datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(tokenPerExpertData->GetDataType())), + return false); + auto tmpData = context->GetInputDesc(INPUT_TMP_DATA_INDEX); // scratch space used inside the operator; same size as recvData + OP_TILING_CHECK(tmpData == nullptr, OP_LOGE(nodeName, "tmpData is null."), return false); + OP_TILING_CHECK( + (tmpData->GetDataType() != ge::DT_BF16) && (tmpData->GetDataType() != ge::DT_FLOAT16) && + (tmpData->GetDataType() != ge::DT_FLOAT) && (tmpData->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, "tmpData datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(tmpData->GetDataType())), + return false); + + auto sendDataOffset = context->GetOutputDesc(OUTPUT_SEND_DATA_OFFSET_INDEX); + OP_TILING_CHECK(sendDataOffset == nullptr, OP_LOGE(nodeName, "sendDataOffset is null."), return false); + OP_TILING_CHECK( + (sendDataOffset->GetDataType() != ge::DT_BF16) && (sendDataOffset->GetDataType() != ge::DT_FLOAT16) && + (sendDataOffset->GetDataType() != ge::DT_FLOAT) && (sendDataOffset->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, + "sendDataOffset datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(sendDataOffset->GetDataType())), + return false); + + auto recvData = context->GetOutputDesc(OUTPUT_RECV_DATA_INDEX); + OP_TILING_CHECK(recvData == nullptr, OP_LOGE(nodeName, "recvData is null."), return false); + OP_TILING_CHECK( + (recvData->GetDataType() != ge::DT_BF16) && (recvData->GetDataType() != ge::DT_FLOAT16) && + (recvData->GetDataType() != ge::DT_FLOAT) && (recvData->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, + "recvData datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(recvData->GetDataType())), + return false); + + auto tokenServerIdx = context->GetOutputDesc(OUTPUT_TOKEN_SERVER_IDX_INDEX); + OP_TILING_CHECK(tokenServerIdx == nullptr, OP_LOGE(nodeName, "tokenServerIdx is null."), return false); + OP_TILING_CHECK( + (tokenServerIdx->GetDataType() != ge::DT_BF16) && (tokenServerIdx->GetDataType() != ge::DT_FLOAT16) && + 
(tokenServerIdx->GetDataType() != ge::DT_FLOAT) && (tokenServerIdx->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, + "tokenServerIdx datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(tokenServerIdx->GetDataType())), + return false); + + auto tokenUniquePerServer = context->GetOutputDesc(OUTPUT_TOKEN_UNIQUE_PER_SERVER_INDEX); + OP_TILING_CHECK(tokenUniquePerServer == nullptr, OP_LOGE(nodeName, "tokenUniquePerServer is null."), return false); + OP_TILING_CHECK( + (tokenUniquePerServer->GetDataType() != ge::DT_BF16) && + (tokenUniquePerServer->GetDataType() != ge::DT_FLOAT16) && + (tokenUniquePerServer->GetDataType() != ge::DT_FLOAT) && + (tokenUniquePerServer->GetDataType() != ge::DT_INT32), + OP_LOGE( + nodeName, + "tokenUniquePerServer datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(tokenUniquePerServer->GetDataType())), + return false); + + auto epRankTokenCnt = context->GetOutputDesc(OUTPUT_EP_RANK_TOKEN_CNT_INDEX); + OP_TILING_CHECK(epRankTokenCnt == nullptr, OP_LOGE(nodeName, "epRankTokenCnt is null."), return false); + OP_TILING_CHECK( + (epRankTokenCnt->GetDataType() != ge::DT_BF16) && (epRankTokenCnt->GetDataType() != ge::DT_FLOAT16) && + (epRankTokenCnt->GetDataType() != ge::DT_FLOAT) && (epRankTokenCnt->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, + "epRankTokenCnt datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(epRankTokenCnt->GetDataType())), + return false); + + auto localEpTokenCnt = context->GetOutputDesc(OUTPUT_LOCAL_EP_TOKEN_CNT_INDEX); + OP_TILING_CHECK(localEpTokenCnt == nullptr, OP_LOGE(nodeName, "localEpTokenCnt is null."), return false); + OP_TILING_CHECK( + (localEpTokenCnt->GetDataType() != ge::DT_BF16) && (localEpTokenCnt->GetDataType() != ge::DT_FLOAT16) && + (localEpTokenCnt->GetDataType() != ge::DT_FLOAT) && (localEpTokenCnt->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, + "localEpTokenCnt datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(localEpTokenCnt->GetDataType())), + return false); + + auto srcOffsetRankTokenIdx = context->GetOutputDesc(OUTPUT_SRC_OFFSET_RANK_TOKEN_INDEX); + OP_TILING_CHECK(srcOffsetRankTokenIdx == nullptr, OP_LOGE(nodeName, "srcOffsetRankTokenIdx is null."), + return false); + OP_TILING_CHECK( + (srcOffsetRankTokenIdx->GetDataType() != ge::DT_BF16) && + (srcOffsetRankTokenIdx->GetDataType() != ge::DT_FLOAT16) && + (srcOffsetRankTokenIdx->GetDataType() != ge::DT_FLOAT) && + (srcOffsetRankTokenIdx->GetDataType() != ge::DT_INT32), + OP_LOGE( + nodeName, + "srcOffsetRankTokenIdx datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(srcOffsetRankTokenIdx->GetDataType())), + return false); + + auto dstOffsetRankTokenIdx = context->GetOutputDesc(OUTPUT_DST_OFFSET_RANK_TOKEN_INDEX); + OP_TILING_CHECK(dstOffsetRankTokenIdx == nullptr, OP_LOGE(nodeName, "dstOffsetRankTokenIdx is null."), + return false); + OP_TILING_CHECK( + (dstOffsetRankTokenIdx->GetDataType() != ge::DT_BF16) && + (dstOffsetRankTokenIdx->GetDataType() != ge::DT_FLOAT16) && + (dstOffsetRankTokenIdx->GetDataType() != ge::DT_FLOAT) && + (dstOffsetRankTokenIdx->GetDataType() != ge::DT_INT32), + OP_LOGE( + nodeName, + "dstOffsetRankTokenIdx datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(dstOffsetRankTokenIdx->GetDataType())), + return false); + + auto offsetInner = 
context->GetOutputDesc(OUTPUT_OFFSET_INNER_INDEX); + OP_TILING_CHECK(offsetInner == nullptr, OP_LOGE(nodeName, "offsetInner is null."), return false); + OP_TILING_CHECK( + (offsetInner->GetDataType() != ge::DT_BF16) && (offsetInner->GetDataType() != ge::DT_FLOAT16) && + (offsetInner->GetDataType() != ge::DT_FLOAT) && (offsetInner->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, + "offsetInner datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(offsetInner->GetDataType())), + return false); + + auto countOuter = context->GetOutputDesc(OUTPUT_COUNT_OUTER_INDEX); + OP_TILING_CHECK(countOuter == nullptr, OP_LOGE(nodeName, "countOuter is null."), return false); + OP_TILING_CHECK( + (countOuter->GetDataType() != ge::DT_BF16) && (countOuter->GetDataType() != ge::DT_FLOAT16) && + (countOuter->GetDataType() != ge::DT_FLOAT) && (countOuter->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, + "countOuter datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(countOuter->GetDataType())), + return false); + + auto expandIdx = context->GetOutputDesc(OUTPUT_EXPAND_IDX_INDEX); + OP_TILING_CHECK(expandIdx == nullptr, OP_LOGE(nodeName, "expandIdx is null."), return false); + OP_TILING_CHECK( + (expandIdx->GetDataType() != ge::DT_BF16) && (expandIdx->GetDataType() != ge::DT_FLOAT16) && + (expandIdx->GetDataType() != ge::DT_FLOAT) && (expandIdx->GetDataType() != ge::DT_INT32), + OP_LOGE(nodeName, + "expandIdx datatype is invalid, datatype should be bf16 or float16 or float or int, but is %d.", + static_cast(expandIdx->GetDataType())), + return false); + + // Verify the size of the win area + NotifyDispatchA2TilingData *tilingData = context->GetTilingData(); + uint64_t maxWindowSize = Mc2TilingUtils::GetMaxWindowSize(); + uint64_t actualSize = dataSize * tilingData->notifyDispatchInfoA2.sendCount; + if (actualSize > maxWindowSize) { + OP_LOGE(nodeName, "HCCL_BUFFSIZE is too SMALL, should be larger than %lu", actualSize); + return false; + } + return true; +} + +static ge::graphStatus TilingCheckTensor(gert::TilingContext *context, const char *nodeName) +{ + OP_TILING_CHECK(!CheckTensorDataType(context, nodeName), OP_LOGE(nodeName, "params dataType is invalid."), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus NotifyDispatchA2TilingFuncImpl(gert::TilingContext *context) +{ + const char *nodeName = context->GetNodeName(); + OP_LOGD(nodeName, "Enter NotifyDispatchA2TilingFuncImpl."); + NotifyDispatchA2TilingData *tilingData = context->GetTilingData(); + OP_TILING_CHECK(tilingData == nullptr, OP_LOGE(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED); + std::string commGroup = ""; + OP_LOGI(nodeName, "Enter NotifyDispatchA2 tiling check func."); + + OP_TILING_CHECK(GetAttrAndSetTilingData(context, nodeName, *tilingData, commGroup) != ge::GRAPH_SUCCESS, + OP_LOGE(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED); + + OP_TILING_CHECK(TilingCheckTensor(context, nodeName) != ge::GRAPH_SUCCESS, + OP_LOGE(nodeName, "Tiling check param failed."), return ge::GRAPH_FAILED); + + OP_TILING_CHECK(SetWorkSpace(context, nodeName) != ge::GRAPH_SUCCESS, + OP_LOGE(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED); + SetHcommCfg(context, tilingData, commGroup); + + int tilingKey = TILING_KEY_INT; + auto sendDtype = context->GetInputDesc(0)->GetDataType(); + if (sendDtype == ge::DT_FLOAT16) { + tilingKey = TILING_KEY_FLOAT16; + } else if 
(sendDtype == ge::DT_BF16) { + tilingKey = TILING_KEY_BFLOAT16; + } else if (sendDtype == ge::DT_FLOAT) { + tilingKey = TILING_KEY_FLOAT; + } + + fe::PlatFormInfos *platformInfoPtr = context->GetPlatformInfo(); + fe::PlatFormInfos &platformInfo = *platformInfoPtr; + + std::string socVersion; + (void)platformInfo.GetPlatformResWithLock("version", "Short_SoC_version", socVersion); + + if (socVersion == "Ascend910B") { + tilingKey = tilingKey + TILING_KEY_A2_TYPE; + } + context->SetTilingKey(tilingKey); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + uint32_t blockDim; + uint32_t aivNum = ascendcPlatform.GetCoreNumAiv(); + uint64_t ubSize = 0UL; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + + blockDim = aivNum; + context->SetBlockDim(blockDim); + tilingData->notifyDispatchInfoA2.totalUbSize = ubSize; + tilingData->notifyDispatchInfoA2.aivNum = aivNum; + OP_LOGD(nodeName, "blockDim=%u, aivNum=%u, ubSize=%lu", blockDim, aivNum, ubSize); + PrintTilingDataInfo(nodeName, *tilingData); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus NotifyDispatchA2TilingFunc(gert::TilingContext *context) +{ + ge::graphStatus ret = NotifyDispatchA2TilingFuncImpl(context); + return ret; +} + +struct NotifyDispatchA2CompileInfo {}; +ge::graphStatus TilingParseForNotifyDispatchA2(gert::TilingParseContext *context) +{ + (void)context; + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(NotifyDispatchA2) + .Tiling(NotifyDispatchA2TilingFunc) + .TilingParse(TilingParseForNotifyDispatchA2); +} // namespace optiling diff --git a/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_layout.cpp b/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_layout.cpp new file mode 100644 index 00000000..402a9eef --- /dev/null +++ b/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_layout.cpp @@ -0,0 +1,38 @@ +#include +#include "graph/types.h" +#include "aclnn_dispatch_layout.h" +#include "aclnnInner_dispatch_layout.h" + +enum NnopbaseHcclServerType { + NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, + NNOPBASE_HCCL_SERVER_TYPE_MTE, + NNOPBASE_HCCL_SERVER_TYPE_END +}; +extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType); + +#ifdef __cplusplus +extern "C" { +#endif + +aclnnStatus aclnnDispatchLayoutGetWorkspaceSize(const aclTensor *topkIdx, int64_t numTokens, int64_t numRanks, + int64_t numExperts, int64_t numTopk, int64_t localRankSize, + const aclTensor *numTokensPerRank, const aclTensor *numTokensPerExpert, + const aclTensor *isTokenInRank, const aclTensor *totalData, + uint64_t *workspaceSize, aclOpExecutor **executor) +{ + return aclnnInnerDispatchLayoutGetWorkspaceSize(topkIdx, numTokens, numRanks, numExperts, numTopk, localRankSize, + numTokensPerRank, numTokensPerExpert, isTokenInRank, totalData, + workspaceSize, executor); +} + +aclnnStatus aclnnDispatchLayout(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream) +{ + if (NnopbaseSetHcclServerType) { + NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE); + } + return aclnnInnerDispatchLayout(workspace, workspaceSize, executor, stream); +} + +#ifdef __cplusplus +} +#endif diff --git a/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_layout.h b/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_layout.h new file mode 100644 index 00000000..375836d3 --- /dev/null +++ b/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_layout.h @@ -0,0 +1,42 @@ +#ifndef ACLNN_DISPATCH_LAYOUT_H_ +#define ACLNN_DISPATCH_LAYOUT_H_ + +#include 
"aclnn/acl_meta.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* function: aclnnDispatchLayoutGetWorkspaceSize + * topkIdx : required + * numTokens : required + * numRanks : required + * numExperts : required + * numTopk : required + * localRankSize : required + * numTokensPerRank : required + * numTokensPerExpert : required + * isTokenInRank : required + * totalData : required + * workspaceSize : size of workspace(output). + * executor : executor context(output). + */ +__attribute__((visibility("default"))) aclnnStatus aclnnDispatchLayoutGetWorkspaceSize( + const aclTensor *topkIdx, int64_t numTokens, int64_t numRanks, int64_t numExperts, int64_t numTopk, + int64_t localRankSize, const aclTensor *numTokensPerRank, const aclTensor *numTokensPerExpert, + const aclTensor *isTokenInRank, const aclTensor *totalData, uint64_t *workspaceSize, aclOpExecutor **executor); + +/* function: aclnnDispatchLayout + * workspace : workspace memory addr(input). + * workspaceSize : size of workspace(input). + * executor : executor context(input). + * stream : acl stream. + */ +__attribute__((visibility("default"))) aclnnStatus aclnnDispatchLayout(void *workspace, uint64_t workspaceSize, + aclOpExecutor *executor, aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_normal_a2.cpp b/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_normal_a2.cpp new file mode 100644 index 00000000..7cdbb72f --- /dev/null +++ b/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_normal_a2.cpp @@ -0,0 +1,53 @@ +#include +#include +#include "graph/types.h" +#include "aclnn_dispatch_normal_a2.h" +#include "aclnn/opdev/platform.h" +#include "aclnnInner_dispatch_normal_a2.h" + +enum NnopbaseHcclServerType { + NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, + NNOPBASE_HCCL_SERVER_TYPE_MTE, + NNOPBASE_HCCL_SERVER_TYPE_END +}; +extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType); + +#ifdef __cplusplus +extern "C" { +#endif + +aclnnStatus aclnnDispatchNormalA2GetWorkspaceSize( + const aclTensor *x, const aclTensor *expertIds, const aclTensor *scales, const aclTensor *xActiveMask, + const aclTensor *expertScales, const aclTensor *tokenServerIdx, const aclTensor *tokenServerCnt, + const aclTensor *epRankTokenCnt, const aclTensor *srcOffsetRankTokenIdx, const aclTensor *dstOffsetRankTokenIdx, + char *groupEp, int64_t epWorldSize, int64_t epRankId, int64_t moeExpertNum, char *groupTp, int64_t tpWorldSize, + int64_t tpRankId, int64_t expertShardType, int64_t sharedExpertNum, int64_t sharedExpertRankNum, int64_t quantMode, + int64_t globalBs, int64_t expertTokenNumsType, const aclTensor *recvX, const aclTensor *dynamicScales, + const aclTensor *expandIdx, const aclTensor *expertTokenNums, const aclTensor *epRecvCount, + const aclTensor *expandScales, const aclTensor *waitRecvCostStats, uint64_t *workspaceSize, + aclOpExecutor **executor) +{ + // printf("[aclnnDispatch] rank:%d", epRankId); + return aclnnInnerDispatchNormalA2GetWorkspaceSize( + x, expertIds, scales, xActiveMask, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt, + srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, groupEp, epWorldSize, epRankId, moeExpertNum, groupTp, + tpWorldSize, tpRankId, expertShardType, sharedExpertNum, sharedExpertRankNum, quantMode, globalBs, + expertTokenNumsType, recvX, dynamicScales, expandIdx, expertTokenNums, epRecvCount, expandScales, + waitRecvCostStats, workspaceSize, executor); +} + +aclnnStatus aclnnDispatchNormalA2(void 
*workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream) +{ + if (NnopbaseSetHcclServerType) { + if (op::GetCurrentPlatformInfo().GetSocVersion() == op::SocVersion::ASCEND910B) { + NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_AICPU); + } else { + NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE); + } + } + return aclnnInnerDispatchNormalA2(workspace, workspaceSize, executor, stream); +} + +#ifdef __cplusplus +} +#endif diff --git a/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_normal_a2.h b/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_normal_a2.h new file mode 100644 index 00000000..3f0f3c35 --- /dev/null +++ b/csrc/deepep/ops2/op_host/op_api/aclnn_dispatch_normal_a2.h @@ -0,0 +1,28 @@ +#ifndef ACLNN_DISPATCH_NORMAL_A2_H_ +#define ACLNN_DISPATCH_NORMAL_A2_H_ + +#include "aclnn/acl_meta.h" + +#ifdef __cplusplus +extern "C" { +#endif + +__attribute__((visibility("default"))) aclnnStatus aclnnDispatchNormalA2GetWorkspaceSize( + const aclTensor *x, const aclTensor *expertIds, const aclTensor *scales, const aclTensor *xActiveMask, + const aclTensor *expertScales, const aclTensor *tokenServerIdx, const aclTensor *tokenServerCnt, + const aclTensor *epRankTokenCnt, const aclTensor *srcOffsetRankTokenIdx, const aclTensor *dstOffsetRankTokenIdx, + char *groupEp, int64_t epWorldSize, int64_t epRankId, int64_t moeExpertNum, char *groupTp, int64_t tpWorldSize, + int64_t tpRankId, int64_t expertShardType, int64_t sharedExpertNum, int64_t sharedExpertRankNum, int64_t quantMode, + int64_t globalBs, int64_t expertTokenNumsType, const aclTensor *recvX, const aclTensor *dynamicScales, + const aclTensor *expandIdx, const aclTensor *expertTokenNums, const aclTensor *epRecvCount, + const aclTensor *expandScales, const aclTensor *waitRecvCostStats, uint64_t *workspaceSize, + aclOpExecutor **executor); + +__attribute__((visibility("default"))) aclnnStatus aclnnDispatchNormalA2(void *workspace, uint64_t workspaceSize, + aclOpExecutor *executor, aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/csrc/deepep/ops2/op_host/op_api/aclnn_moe_distribute_combine_a2.cpp b/csrc/deepep/ops2/op_host/op_api/aclnn_moe_distribute_combine_a2.cpp new file mode 100644 index 00000000..2f974e79 --- /dev/null +++ b/csrc/deepep/ops2/op_host/op_api/aclnn_moe_distribute_combine_a2.cpp @@ -0,0 +1,63 @@ +#include +#include +#include "aclnn_moe_distribute_combine_a2.h" +#include "aclnnInner_moe_distribute_combine_a2.h" +#include "aclnn/opdev/platform.h" +// #include "aclnn_kernels/common/op_error_check.h" +// #include "opdev/op_log.h" +// #include "opdev/common_types.h" +// #include "opdev/platform.h" + +// using namespace op; + +#ifdef __cplusplus +extern "C" { +#endif + +enum NnopbaseHcclServerType { + NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, + NNOPBASE_HCCL_SERVER_TYPE_MTE, + NNOPBASE_HCCL_SERVER_TYPE_END +}; +extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType); + +static constexpr size_t HCCL_GROUP_NAME_MAX = 128U; + +aclnnStatus aclnnMoeDistributeCombineA2GetWorkspaceSize( + const aclTensor *expandX, const aclTensor *expertIds, const aclTensor *expandIdx, const aclTensor *epSendCounts, + const aclTensor *expertScales, const aclTensor *tpSendCounts, const aclTensor *xActiveMask, + const aclTensor *activationScale, const aclTensor *weightScale, const aclTensor *groupList, + const aclTensor *expandScales, const aclTensor *offsetInner, const aclTensor *offsetOuter, + const aclTensor *countOuter, char 
*groupEp, int64_t epWorldSize, int64_t epRankId, int64_t moeExpertNum, + char *groupTp, int64_t tpWorldSize, int64_t tpRankId, int64_t expertShardType, int64_t sharedExpertNum, + int64_t sharedExpertRankNum, int64_t globalBs, int64_t outDtype, int64_t commQuantMode, int64_t groupListType, + aclTensor *x, uint64_t *workspaceSize, aclOpExecutor **executor) +{ + // printf("aclnnMoeDistributeCombineA2GetWorkspaceSize"); + + aclnnStatus ret = aclnnInnerMoeDistributeCombineA2GetWorkspaceSize( + expandX, expertIds, expandIdx, epSendCounts, expertScales, tpSendCounts, xActiveMask, activationScale, + weightScale, groupList, expandScales, offsetInner, offsetOuter, countOuter, groupEp, epWorldSize, epRankId, + moeExpertNum, groupTp, tpWorldSize, tpRankId, expertShardType, sharedExpertNum, sharedExpertRankNum, globalBs, + outDtype, commQuantMode, groupListType, x, workspaceSize, executor); + return ret; +} + +aclnnStatus aclnnMoeDistributeCombineA2(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, + aclrtStream stream) +{ + if (NnopbaseSetHcclServerType) { + if (op::GetCurrentPlatformInfo().GetSocVersion() == op::SocVersion::ASCEND910B) { + NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_AICPU); + } else { + NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE); + } + } + // printf("aclnnMoeDistributeCombineA2"); + aclnnStatus ret = aclnnInnerMoeDistributeCombineA2(workspace, workspaceSize, executor, stream); + return ret; +} + +#ifdef __cplusplus +} +#endif diff --git a/csrc/deepep/ops2/op_host/op_api/aclnn_moe_distribute_combine_a2.h b/csrc/deepep/ops2/op_host/op_api/aclnn_moe_distribute_combine_a2.h new file mode 100644 index 00000000..b0032a6a --- /dev/null +++ b/csrc/deepep/ops2/op_host/op_api/aclnn_moe_distribute_combine_a2.h @@ -0,0 +1,90 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+
+#ifndef OP_API_INC_MOE_DISTRIBUTE_COMBINE_A2_
+#define OP_API_INC_MOE_DISTRIBUTE_COMBINE_A2_
+
+#include
+#include "aclnn/aclnn_base.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * Operator function: implements the MoeDistributeCombine functionality.
+ * @brief First-stage interface of aclnnMoeDistributeCombine; computes the workspace size required by the
+ *        given computation flow.
+ * @domain aclnn_ops_infer
+ * @param [in] expandX: compute input, Tensor, dtype float16/bfloat16, must be 2-D, ND format.
+ * @param [in] expertIds: compute input, Tensor, dtype int32, must be 2-D, ND format.
+ * @param [in] expandIdx: compute input, Tensor, dtype int32, must be 1-D, ND format; records how many MoE
+ *            experts process each token.
+ * @param [in] epSendCounts: compute input, Tensor, dtype int32, must be 1-D, ND format.
+ * @param [in] expertScales: compute input, Tensor, dtype float32, must be 2-D, ND format.
+ * @param [in] tpSendCounts: compute input, Tensor, dtype int32, must be 1-D, ND format; pass null when there
+ *            is no TP-domain communication.
+ * @param [in] xActiveMask: compute input, Tensor, dtype bool, must be 1-D, ND format. Reserved parameter,
+ *            currently unused; pass null.
+ * @param [in] activationScale: compute input, Tensor, dtype float32, must be 1-D, ND format; quantization
+ *            scale of the left matrix produced by GMM, required when x is int32. Reserved parameter,
+ *            currently unused; pass null.
+ * @param [in] weightScale: compute input, Tensor, dtype float32, must be 2-D, ND format; quantization scale
+ *            of the right matrix produced by GMM, required when x is int32. Reserved parameter, currently
+ *            unused; pass null.
+ * @param [in] groupList: compute input, Tensor, dtype int64, must be 1-D, ND format; GMM group sizes,
+ *            required when the E dimension of weightScale is greater than 1. Reserved parameter, currently
+ *            unused; pass null.
+ * @param [in] expandScales: compute input, Tensor, dtype float32, must be 1-D, ND format. Not used on
+ *            Ascend 910_93 yet; pass null.
+ * @param [in] groupEp: compute input, str. Name of the EP (expert-parallel) communication group. String
+ *            length in [1, 128); must differ from groupTp.
+ * @param [in] epWorldSize: compute input, int. Size of the EP communication group. On Ascend 910_93 the
+ *            supported values are 8/16/32/64/128/144/256/288.
+ * @param [in] epRankId: compute input, int. Rank id of this device in the EP group, in [0, epWorldSize);
+ *            epRankId must be unique within one EP communication group.
+ * @param [in] moeExpertNum: compute input, int. Number of MoE experts. On Ascend 910_93 the range is
+ *            [1, 512], and moeExpertNum % (epWorldSize - sharedExpertRankNum) must equal 0.
+ * @param [in] groupTp: optional compute input, str. Name of the TP (data-parallel) communication group.
+ *            Pass null when there is no TP group; otherwise the string length is in [1, 128) and it must
+ *            differ from groupEp.
+ * @param [in] tpWorldSize: optional compute input, int. Size of the TP communication group, in [0, 2];
+ *            0 and 1 mean no TP-domain communication, and only 2 is supported when TP communication is used.
+ * @param [in] tpRankId: optional compute input, int. Rank id of this device in the TP group, in
+ *            [0, tpWorldSize); tpRankId must be unique within one TP communication group.
+ * @param [in] expertShardType: optional compute input, int. Expert shard type; only 0 is supported currently.
+ * @param [in] sharedExpertNum: optional compute input, int. Number of shared experts, in [0, 1]; 0 means no
+ *            shared expert.
+ * @param [in] sharedExpertRankNum: optional compute input, int. Number of shared-expert ranks; 0 means there
+ *            are no shared-expert ranks, and when non-zero it must satisfy
+ sharedExpertRankNum
+#include
+#include "graph/types.h"
+#include "aclnn_notify_dispatch_a2.h"
+#include "aclnnInner_notify_dispatch_a2.h"
+
+extern void NnopbaseOpLogE(const aclnnStatus code, const char *const expr);
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+enum NnopbaseHcclServerType {
+    NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
+    NNOPBASE_HCCL_SERVER_TYPE_MTE,
+    NNOPBASE_HCCL_SERVER_TYPE_END
+};
+extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
+
+aclnnStatus aclnnNotifyDispatchA2GetWorkspaceSize(
+    const aclTensor *sendData, const aclTensor *tokenPerExpertData, const aclTensor *tmpData, int64_t sendCount,
+    int64_t numTokens, int64_t topkNum, int64_t numExperts, char *commGroup, int64_t rankSize, int64_t rankId,
+    int64_t localRankSize, int64_t localRankId, const aclTensor *sendDataOffset, const aclTensor *recvData,
+    const aclTensor *tokenServerIdx, const aclTensor *tokenUniquePerServer, const aclTensor *epRankTokenCnt,
+    const aclTensor *localEpTokenCnt, const aclTensor *srcOffsetRankTokenIdx, const aclTensor *dstOffsetRankTokenIdx,
+    const aclTensor *offsetInner, const aclTensor *countOuter, const aclTensor *expandIdx, uint64_t *workspaceSize,
+    aclOpExecutor **executor)
+{
+    return aclnnInnerNotifyDispatchA2GetWorkspaceSize(
+        sendData, tokenPerExpertData, tmpData, sendCount,
numTokens, topkNum, numExperts, commGroup, rankSize, rankId, + localRankSize, localRankId, sendDataOffset, recvData, tokenServerIdx, tokenUniquePerServer, epRankTokenCnt, + localEpTokenCnt, srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, offsetInner, countOuter, expandIdx, + workspaceSize, executor); +} + +aclnnStatus aclnnNotifyDispatchA2(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream) +{ + if (NnopbaseSetHcclServerType) { + NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_AICPU); // A2 需要为 AICPU + } + return aclnnInnerNotifyDispatchA2(workspace, workspaceSize, executor, stream); +} + +#ifdef __cplusplus +} +#endif diff --git a/csrc/deepep/ops2/op_host/op_api/aclnn_notify_dispatch_a2.h b/csrc/deepep/ops2/op_host/op_api/aclnn_notify_dispatch_a2.h new file mode 100644 index 00000000..6694f476 --- /dev/null +++ b/csrc/deepep/ops2/op_host/op_api/aclnn_notify_dispatch_a2.h @@ -0,0 +1,61 @@ + +#ifndef ACLNN_NOTIFY_DISPATCH_A2_H_ +#define ACLNN_NOTIFY_DISPATCH_A2_H_ + +#include "aclnn/acl_meta.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* function: aclnnNotifyDispatchA2GetWorkspaceSize + * parameters : + * sendData : required + * tokenPerExpertData : required + * sendCount : required + * numTokens : required + * topkNum : required + * numExperts : required + * commGroup : required + * rankSize : required + * rankId : required + * localRankSize : required + * localRankId : required + * sendDataOffset : required + * recvData : required + * tokenServerIdx : required + * tokenUniquePerServer : required + * epRankTokenCnt : required + * localEpTokenCnt : required + * srcOffsetRankTokenIdx : required + * dstOffsetRankTokenIdx : required + * offsetInner : required + * countOuter : required + * expandIdx : required + * workspaceSize : size of workspace(output). + * executor : executor context(output). + */ +__attribute__((visibility("default"))) aclnnStatus aclnnNotifyDispatchA2GetWorkspaceSize( + const aclTensor *sendData, const aclTensor *tokenPerExpertData, const aclTensor *tmpData, int64_t sendCount, + int64_t numTokens, int64_t topkNum, int64_t numExperts, char *commGroup, int64_t rankSize, int64_t rankId, + int64_t localRankSize, int64_t localRankId, const aclTensor *sendDataOffset, const aclTensor *recvData, + const aclTensor *tokenServerIdx, const aclTensor *tokenUniquePerServer, const aclTensor *epRankTokenCnt, + const aclTensor *localEpTokenCnt, const aclTensor *srcOffsetRankTokenIdx, const aclTensor *dstOffsetRankTokenIdx, + const aclTensor *offsetInner, const aclTensor *countOuter, const aclTensor *expandIdx, uint64_t *workspaceSize, + aclOpExecutor **executor); + +/* function: aclnnNotifyDispatch + * parameters : + * workspace : workspace memory addr(input). + * workspaceSize : size of workspace(input). + * executor : executor context(input). + * stream : acl stream. 
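+ * usage : illustrative two-stage aclnn call sequence (assumed typical usage; error handling omitted):
+ *           uint64_t workspaceSize = 0; aclOpExecutor *executor = nullptr;
+ *           aclnnNotifyDispatchA2GetWorkspaceSize(..., &workspaceSize, &executor);
+ *           void *workspace = nullptr;
+ *           if (workspaceSize > 0) { aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); }
+ *           aclnnNotifyDispatchA2(workspace, workspaceSize, executor, stream);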
+ */ +__attribute__((visibility("default"))) aclnnStatus aclnnNotifyDispatchA2(void *workspace, uint64_t workspaceSize, + aclOpExecutor *executor, aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/csrc/deepep/ops2/op_host/tiling_args.h b/csrc/deepep/ops2/op_host/tiling_args.h new file mode 100644 index 00000000..950cbe90 --- /dev/null +++ b/csrc/deepep/ops2/op_host/tiling_args.h @@ -0,0 +1,9 @@ +#ifndef TILING_ARGS_H +#define TILING_ARGS_H +#include + +namespace Moe { +constexpr uint64_t COMBINE_STATE_WIN_OFFSET = 3U * 1024UL * 1024UL; +constexpr uint64_t NOTIFY_DISPATCH_WIN_OFFSET = 204U * 1024UL * 1024UL; +} // namespace Moe +#endif // TILING_ARGS_H diff --git a/csrc/deepep/ops2/op_kernel/CMakeLists.txt b/csrc/deepep/ops2/op_kernel/CMakeLists.txt new file mode 100644 index 00000000..c8221a5f --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/CMakeLists.txt @@ -0,0 +1,8 @@ +# set custom compile options +if ("${CMAKE_BUILD_TYPE}x" STREQUAL "Debugx") + add_ops_compile_options(ALL OPTIONS -g -O0 ) +endif() + +add_ops_compile_options(ALL OPTIONS -DASCENDC_DUMP=0 --cce-auto-sync=off) + +add_kernels_compile() diff --git a/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h b/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h new file mode 100644 index 00000000..f43b3c67 --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/a2/cam_moe_distribute_dispatch_a2_layered.h @@ -0,0 +1,1142 @@ +#ifndef CAM_MOE_DISTRIBUTE_DISPATCH_A2_LAYERED_H +#define CAM_MOE_DISTRIBUTE_DISPATCH_A2_LAYERED_H + +#include "kernel_operator.h" +#include "kernel_tiling/kernel_tiling.h" +#include "../cam_moe_distribute_dispatch_tiling.h" +#include "../moe_distribute_base.h" +#include "../comm_args.h" + +namespace MoeDistributeDispatchA2Impl { +constexpr uint32_t STATE_OFFSET = 512; // 状态空间偏移地址 +constexpr uint32_t STATUS_SIZE_LAYERED = 1024 * 1024; // 1M +constexpr uint32_t RDMA_BUFFER_ALIGN = 4 * 1024; +constexpr uint32_t SELF_STATE_OFFSET = 512 * 1024; // 本卡状态空间偏移地址 +constexpr uint32_t SERVER_RANK_SIZE = 8; +constexpr uint32_t INFO_NUM_IN_TOKENSTRUCK = 4; // 在Token后加入3种信息:expIds, weights, tokenIdx, scales +constexpr uint32_t B64_PER_BLOCK = 4; +constexpr uint32_t PER_MSG_RDMA_SEND_TIME = 2; +constexpr uint32_t B32_PER_BLOCK = 8; +constexpr uint32_t UB_32B_ALIGN = 32; +constexpr uint32_t EXP_TOKEN_COUNT_FLAG_CNT = UB_32B_ALIGN / sizeof(int32_t); // 8 +constexpr uint32_t DISPATCH_TOKEN_UB_SIZE = 176 * 1024; +constexpr uint32_t IPC_MAGIC_OFFSET = 2 * 1024 * 1024 - 64 * 32; +constexpr uint32_t IPC_TOKEN_CNT_OFFSET = 2 * 1024 * 1024; +constexpr uint32_t IPC_DATA_OFFSET = 4 * 1024 * 1024; +constexpr uint32_t NOTIFY_OFFSET = 404 * 1024 * 1024; // 204 +constexpr uint32_t IPC_BUFF_ALIGN = 512; +constexpr uint32_t TOKEN_COUNT_SIZE = 32; +constexpr uint32_t FLAG_U32_CNT = TOKEN_COUNT_SIZE / 4; +constexpr int32_t IPC_FLAG_STEP_1 = 1; +constexpr int32_t IPC_FLAG_STEP_2 = 2; +constexpr uint32_t TBUF_TEMP_OFFSET = 8 * 1024; +constexpr uint32_t TBUF_OFFSET_ALIGN_B32_CNT = 2 * 1024 / sizeof(int32_t); +constexpr uint32_t RDMA_DATA_SIZE = 100U * 1024U * 1024U; +constexpr uint32_t EXTRA_TOKEN_INFO_NUM = 4U; // 专家信息 权重信息 量化Scale 到达标志位 +constexpr uint32_t BITS32_PER_BLOCK = 8U; +constexpr static uint32_t BW_ITEM_SIZE = 32; +constexpr uint32_t FLAG_VALUE = 0xFFFFFFFF; +constexpr uint32_t BS_UPPER = 4096; + +#define TemplateMC2TypeA2layeredClass \ + typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist +#define TemplateMC2TypeA2layeredFunc XType, 
ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist + +using namespace AscendC; +using namespace Cam; +template +class CamMoeDistributeDispatchA2Layered +{ + template + inline __aicore__ T RoundUp(const T val, const T align) + { + static_assert(std::is_arithmetic::value, "T must be an arithmetic type"); + if (align == 0 || val + align - 1 < val) { + return val; + } + return (val + align - 1) / align * align; + } + +public: + __aicore__ inline CamMoeDistributeDispatchA2Layered(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expertScales, + GM_ADDR tokenServerIdx, GM_ADDR tokenServerCnt, GM_ADDR epRankTokenCnt, + GM_ADDR srcOffsetRankTokenIdx, GM_ADDR dstOffsetRankTokenIdx, GM_ADDR expandXOut, + GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, + GM_ADDR epRecvCountsOut, GM_ADDR expandScales, GM_ADDR workspaceGM, TPipe *pipe, + GM_ADDR tilingGM); + __aicore__ inline void Process(); + template + __aicore__ inline void SyncFunc() + { + int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); + } + +private: + __aicore__ inline void Input2Win(); + __aicore__ inline uint32_t GetExpRank(uint32_t expertId); + __aicore__ inline bool IsInSameServer(uint32_t targetRankId); + __aicore__ inline void SetTokenCnt(GlobalTensor globalSet); + __aicore__ inline void CopyTokenToWinOut(uint32_t localTokenIdx, uint32_t tokenIdx, uint32_t dstServerId); + __aicore__ inline void WaitWindow(); + + __aicore__ inline void Win2Ipc(); + __aicore__ inline void Ipc2Out(); + __aicore__ inline void DispatchBetweenServer(); + __aicore__ inline void ConstructDataAndFlagBatchWriteInfo(); + __aicore__ inline void WaitIpcFlag(int32_t flagVal = 1); + __aicore__ inline void SetIpcFlag(int32_t flagVal = 1); + __aicore__ inline void WriteRdmaCntInfo(); + __aicore__ inline void CleanUp(); + __aicore__ inline void QuantProcess(uint32_t sendTokenNum, LocalTensor xTokenLt, + LocalTensor tokenCastLt); + __aicore__ inline int64_t MergeMagicWithValue(int32_t magic, int32_t value); + + TPipe *tpipe_{nullptr}; + GlobalTensor expertIdsGMTensor_; + GlobalTensor expandXOutGMTensor_; + GlobalTensor dynamicScalesOutGMTensor_; + GlobalTensor weightsOutGt; + GlobalTensor dataBatchWriteInfoTensor_; + GlobalTensor sendStatusTensor_; + GlobalTensor readTokensU8Tensor_; + GlobalTensor sendTokensU8Tensor_; + GlobalTensor sendTokensU32Tensor_; + GlobalTensor bufferChosenGlobal_; + GlobalTensor expertToServerGlobalTensor_; + GlobalTensor readStatusTensor_; + GlobalTensor tokenServerIdxGMTensor_; + GlobalTensor tokenServerCntGMTensor_; + + GlobalTensor epRankTokenCntGMTensor_; + GlobalTensor srcOffsetRankTokenIdxGMTensor_; + GlobalTensor dstOffsetRankTokenIdxGMTensor_; + + LocalTensor expertCountTensor_; + LocalTensor batchWriteU64Tensor_; + LocalTensor batchWriteU32Tensor_; + LocalTensor expertToServerCntTensor_; + LocalTensor expertToServerIdxTensor_; + + LocalTensor tokenServerIdxTensor_; + LocalTensor serverCountTensor_; + + TBuf<> tokenServerIdxBuf_; + TBuf<> serverCountBuf_; + + TBuf<> expertCountBuf_; + TBuf<> statusBuf_; + TBuf<> batchWriteInfoBuf_; + TBuf<> expertToServerCntsBuf_; // 总表,int类型只写1/0 + TBuf<> expertToServerIdxBuf_; + TBuf tBuf; + TBuf<> weightBuf_; + + GM_ADDR expandXGM_; + GM_ADDR expandIdxGM_; + GM_ADDR weightsGM_; + GM_ADDR expertTokenNumsOutGM_; + GM_ADDR epRecvCountsGM_; + GM_ADDR statusSpaceGm_; + GM_ADDR windowInGM_; + GM_ADDR windowOutGM_; + GM_ADDR dataBatchWriteInfo_; + GM_ADDR 
expertToServerCntGM_; + GM_ADDR shareAddrs[8]; + GM_ADDR shareAddrWins[8]; + + // tiling侧已确保数据上限,相乘不会越界,因此统一采用uint32_t进行处理 + uint32_t axisBS_{0}; + uint32_t globalBs_{0}; + uint32_t axisH_{0}; + uint32_t axisK_{0}; + uint32_t kAlign_{0}; + uint32_t aivNum_{0}; + uint32_t expertIdsCnt_{0}; + uint32_t worldSize_{0}; + uint32_t rankId_{0}; + uint32_t aivId_{0}; // aiv id + uint32_t moeExpertNum_{0}; // moe专家卡数, 等于worldSize_ - 共享专家卡数 + uint32_t moeExpertNumInServer_{0}; + uint32_t localMoeExpertNum_{0}; + uint32_t SERVER_SIZE_ON_WIN{0}; + uint32_t RANK_SIZE_ON_IPC{0}; + uint32_t WIN_SIZE{0}; + uint32_t bufferId_{0}; + uint32_t totalSize_{0}; + uint32_t totalWinSize_{0}; + uint32_t halfWinSize_{0}; + uint32_t serverNum{0}; + uint32_t expertTokenNumsType_{0}; + uint32_t shareMemOffset_{0}; + // TokenStruck相关 + uint32_t tokenGapInStruct_{0}; + uint32_t infoGapInStruct_{0}; + uint32_t tokenStructLen_{0}; + uint32_t tokenLenInStruct_{0}; + uint32_t expLenInStruct_{0}; + uint32_t weightLenInStruct_{0}; + uint32_t realLenInStruct_{0}; + uint32_t cntLenInStruct_{0}; + uint32_t expOffsetInStruct_{0}; + uint32_t weightOffsetInStruct_{0}; + uint32_t cntOffsetInStruct_{0}; + uint32_t scaleOffsetInStruct_{0}; + int32_t magicVal_{0}; + + uint32_t combineInnerCntOffset; + uint32_t combineInnerCntIndexOffset; + uint32_t combineOuterCntOffset; + uint32_t combineOuterCntIndexOffset; + + Hccl hccl_; + __gm__ HcclOpResParam *winContext_{nullptr}; +}; + +template +__aicore__ inline void CamMoeDistributeDispatchA2Layered::Init( + GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expertScales, GM_ADDR tokenServerIdx, GM_ADDR tokenServerCnt, + GM_ADDR epRankTokenCnt, GM_ADDR srcOffsetRankTokenIdx, GM_ADDR dstOffsetRankTokenIdx, GM_ADDR expandXOut, + GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, GM_ADDR epRecvCountsOut, + GM_ADDR expandScales, GM_ADDR workspaceGM, TPipe *pipe, GM_ADDR tilingGM) +{ + PRINTF("[A2layer Init]\n"); + tpipe_ = pipe; + REGISTER_TILING_DEFAULT(CamMoeDistributeDispatchA2TilingData); + auto tiling = (__gm__ CamMoeDistributeDispatchA2TilingData *)tilingGM; + __gm__ void *mc2InitTiling = (__gm__ void *)(&(tiling->mc2InitTiling)); + __gm__ void *mc2CcTiling = (__gm__ void *)(&(tiling->mc2CcTiling)); + GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tilingGM); + + auto contextGM0 = AscendC::GetHcclContext(); + hccl_.Init(contextGM0, mc2InitTiling); + hccl_.SetCcTiling(mc2CcTiling); + + winContext_ = (__gm__ HcclOpResParam *)contextGM0; + rankId_ = tilingData.moeDistributeDispatchInfo.epRankId; + windowInGM_ = hccl_.GetWindowsInAddr(rankId_) + NOTIFY_OFFSET; + windowOutGM_ = hccl_.GetWindowsOutAddr(rankId_); + // return; + + axisBS_ = tilingData.moeDistributeDispatchInfo.bs; + globalBs_ = tilingData.moeDistributeDispatchInfo.globalBs; + axisH_ = tilingData.moeDistributeDispatchInfo.h; + axisK_ = tilingData.moeDistributeDispatchInfo.k; + aivNum_ = tilingData.moeDistributeDispatchInfo.aivNum; + worldSize_ = tilingData.moeDistributeDispatchInfo.epWorldSize; + moeExpertNum_ = tilingData.moeDistributeDispatchInfo.moeExpertNum; + localMoeExpertNum_ = moeExpertNum_ / worldSize_; + kAlign_ = RoundUp(axisK_, (uint32_t)8); + totalSize_ = winContext_->winSize; + totalWinSize_ = 1000 * 1024 * 1024; // RDMA 1000 MB空间 + shareMemOffset_ = totalWinSize_; + halfWinSize_ = totalWinSize_ / 2; + WIN_SIZE = halfWinSize_ - STATUS_SIZE_LAYERED; + expertTokenNumsType_ = tilingData.moeDistributeDispatchInfo.expertTokenNumsType; + // 校验待完善 + /* + uint64_t 
winSizeMin = + moeExpertNum_ * axisBS_ * (axisH_ * sizeof(XType) + EXTRA_TOKEN_INFO_NUM * kAlign_ * sizeof(uint32_t)) + + IPC_DATA_OFFSET + RDMA_DATA_SIZE; // 考虑负载极其不均衡时,HCCL BUFFSIZE需要开的大小 + assert(winContext_->winSize >= winSizeMin, + "The HCCL_BUFFSIZE is %lluMB, the min value should be %lluMB. \ + epWorldSize:%u, epRankId:%u, moeExpertNum:%u, quantMode:%u, globalBs:%u, bs:%u, k:%u, h:%u, aivNum:%u, \ + isQuant:%d, totalUbSize:%llu, expertTokenNumsType:%u\n", + winContext_->winSize / MB_SIZE, + winSizeMin / MB_SIZE, + tilingData.moeDistributeDispatchInfo.epWorldSize, + tilingData.moeDistributeDispatchInfo.epRankId, + tilingData.moeDistributeDispatchInfo.moeExpertNum, + tilingData.moeDistributeDispatchInfo.quantMode, + tilingData.moeDistributeDispatchInfo.globalBs, + tilingData.moeDistributeDispatchInfo.bs, + tilingData.moeDistributeDispatchInfo.k, + tilingData.moeDistributeDispatchInfo.h, + tilingData.moeDistributeDispatchInfo.aivNum, + tilingData.moeDistributeDispatchInfo.isQuant, + tilingData.moeDistributeDispatchInfo.totalUbSize, + tilingData.moeDistributeDispatchInfo.expertTokenNumsType); + */ + for (int i = 0; i < SERVER_RANK_SIZE; i++) { + shareAddrs[i] = (__gm__ uint8_t *)(reinterpret_cast( + hccl_.GetWindowsInAddr(rankId_ / SERVER_RANK_SIZE * SERVER_RANK_SIZE + i) + shareMemOffset_ + + NOTIFY_OFFSET)); + shareAddrWins[i] = (__gm__ uint8_t *)(reinterpret_cast( + hccl_.GetWindowsInAddr(rankId_ / SERVER_RANK_SIZE * SERVER_RANK_SIZE + i) + NOTIFY_OFFSET + + halfWinSize_ * bufferId_)); + } + + // struce相关信息初始化计算 + tokenStructLen_ = + axisH_ * sizeof(ExpandXOutType) + INFO_NUM_IN_TOKENSTRUCK * (kAlign_ * sizeof(uint32_t)); // token和四元组大小 + tokenLenInStruct_ = axisH_ * sizeof(ExpandXOutType); // 纯token大小 + expLenInStruct_ = kAlign_ * sizeof(uint32_t); // topkId大小 + weightLenInStruct_ = kAlign_ * sizeof(uint32_t); // weight大小 + cntLenInStruct_ = kAlign_ * sizeof(uint32_t); // tokenIdx大小 + realLenInStruct_ = axisK_ * sizeof(uint32_t); // 内存中实际有效部分,跟 axisK_ 有关 + expOffsetInStruct_ = tokenLenInStruct_; // 开始写topkId的起始位置 + weightOffsetInStruct_ = tokenLenInStruct_ + expLenInStruct_; // 开始写weight的起始位置 + cntOffsetInStruct_ = tokenLenInStruct_ + expLenInStruct_ + weightLenInStruct_; // 开始写tokenIdx的起始位置 + scaleOffsetInStruct_ = + tokenLenInStruct_ + expLenInStruct_ + weightLenInStruct_ + cntLenInStruct_; // 开始写scales的起始位置 + tokenGapInStruct_ = (tokenStructLen_ - tokenLenInStruct_) / UB_32B_ALIGN; + infoGapInStruct_ = (tokenStructLen_ - expLenInStruct_) / UB_32B_ALIGN; + + RANK_SIZE_ON_IPC = (totalSize_ - totalWinSize_ - IPC_DATA_OFFSET) / (localMoeExpertNum_ * worldSize_); + RANK_SIZE_ON_IPC = (RANK_SIZE_ON_IPC / IPC_BUFF_ALIGN) * IPC_BUFF_ALIGN; + + aivId_ = GetBlockIdx(); + expertIdsCnt_ = axisBS_ * axisK_; + serverNum = worldSize_ / SERVER_RANK_SIZE; + SERVER_SIZE_ON_WIN = WIN_SIZE / serverNum; + SERVER_SIZE_ON_WIN = (SERVER_SIZE_ON_WIN / RDMA_BUFFER_ALIGN) * RDMA_BUFFER_ALIGN; // 共享内存上每个server块的大小 + + bufferChosenGlobal_.SetGlobalBuffer((__gm__ uint32_t *)(windowInGM_ + WIN_SIZE + worldSize_ * STATE_OFFSET)); + bufferId_ = bufferChosenGlobal_(0); + + windowInGM_ = windowInGM_ + halfWinSize_ * bufferId_; + windowOutGM_ = windowOutGM_ + halfWinSize_ * bufferId_; + + tokenServerIdxGMTensor_.SetGlobalBuffer((__gm__ int32_t *)tokenServerIdx); + tokenServerCntGMTensor_.SetGlobalBuffer((__gm__ int32_t *)tokenServerCnt); + expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)expertIds); + epRankTokenCntGMTensor_.SetGlobalBuffer((__gm__ int32_t *)epRankTokenCnt); + 
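+    // srcOffsetRankTokenIdx / dstOffsetRankTokenIdx are indexed as [expertId, srcRank, tokenIdx]
+    // with an inner stride of BS_UPPER:
+    //   offset = expertId * worldSize_ * BS_UPPER + srcRank * BS_UPPER + i
+    // The first table gives the token slot to read inside the sender's window block, the second
+    // gives the output row in expandXOut; both are consumed in Ipc2Out().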
srcOffsetRankTokenIdxGMTensor_.SetGlobalBuffer((__gm__ int32_t *)srcOffsetRankTokenIdx); + dstOffsetRankTokenIdxGMTensor_.SetGlobalBuffer((__gm__ int32_t *)dstOffsetRankTokenIdx); + + expandXOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)(expandXOut), + worldSize_ * axisBS_ * localMoeExpertNum_ * axisH_); + dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)(dynamicScalesOut)); + + weightsOutGt.SetGlobalBuffer((__gm__ float *)(expandScales)); + + sendTokensU8Tensor_.SetGlobalBuffer((__gm__ uint8_t *)(windowOutGM_)); + readTokensU8Tensor_.SetGlobalBuffer((__gm__ uint8_t *)(windowInGM_)); + sendTokensU32Tensor_.SetGlobalBuffer((__gm__ uint32_t *)(windowOutGM_)); + sendStatusTensor_.SetGlobalBuffer((__gm__ int32_t *)(windowOutGM_ + WIN_SIZE)); + readStatusTensor_.SetGlobalBuffer((__gm__ int32_t *)(windowInGM_ + WIN_SIZE)); + + expertTokenNumsOutGM_ = expertTokenNumsOut; // 无GlobalTensor + epRecvCountsGM_ = epRecvCountsOut; // 无GlobalTensor + statusSpaceGm_ = windowInGM_ + WIN_SIZE; + + expandXGM_ = x; + expandIdxGM_ = expertIds; + weightsGM_ = expertScales; + + dataBatchWriteInfo_ = workspaceGM; + dataBatchWriteInfoTensor_.SetGlobalBuffer((__gm__ uint64_t *)(dataBatchWriteInfo_), + serverNum * PER_MSG_RDMA_SEND_TIME * B64_PER_BLOCK); + + expertToServerCntGM_ = dataBatchWriteInfo_ + serverNum * PER_MSG_RDMA_SEND_TIME * B64_PER_BLOCK * sizeof(uint64_t); + expertToServerGlobalTensor_.SetGlobalBuffer((__gm__ uint32_t *)(expertToServerCntGM_), + RoundUp(axisBS_ * serverNum, B32_PER_BLOCK)); + + combineInnerCntOffset = localMoeExpertNum_ * serverNum * SERVER_RANK_SIZE * sizeof(int32_t); + combineInnerCntIndexOffset = combineInnerCntOffset + globalBs_ * serverNum * sizeof(int32_t); + combineOuterCntOffset = combineInnerCntIndexOffset + globalBs_ * axisK_ * serverNum * sizeof(int32_t); + combineOuterCntIndexOffset = combineOuterCntOffset + axisBS_ * sizeof(int32_t); + moeExpertNumInServer_ = SERVER_RANK_SIZE * localMoeExpertNum_; + + tpipe_->InitBuffer(batchWriteInfoBuf_, PER_MSG_RDMA_SEND_TIME * BW_ITEM_SIZE); // 2 * 32 + + batchWriteU64Tensor_ = batchWriteInfoBuf_.Get(); + batchWriteU32Tensor_ = batchWriteU64Tensor_.template ReinterpretCast(); + + // tpipe_->InitBuffer(expertToServerCntsBuf_, RoundUp(static_cast(axisBS_ * serverNum * sizeof(uint32_t)), + // UB_32B_ALIGN)); // bs * rankSize / 8 * 4 + // expertToServerCntTensor_ = expertToServerCntsBuf_.Get(); + // Duplicate(expertToServerCntTensor_, 0, + // static_cast(RoundUp(static_cast(axisBS_ * serverNum), B32_PER_BLOCK))); + + tpipe_->InitBuffer(statusBuf_, UB_32B_ALIGN); // 32 + + tpipe_->InitBuffer(expertToServerIdxBuf_, serverNum * sizeof(uint32_t)); // rankSize / 8 * 4 + expertToServerIdxTensor_ = expertToServerIdxBuf_.Get(); + + tpipe_->InitBuffer(expertCountBuf_, moeExpertNum_ * sizeof(int32_t)); // moeNum * 4 + expertCountTensor_ = expertCountBuf_.Get(); + Duplicate(expertCountTensor_, 0, moeExpertNum_); + + tpipe_->InitBuffer(tBuf, DISPATCH_TOKEN_UB_SIZE); // 176K + tpipe_->InitBuffer(weightBuf_, UB_32B_ALIGN); // 32 + + GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusSpaceGm_ + SELF_STATE_OFFSET)); + int32_t state = selfStatusTensor(aivId_ * UB_32B_ALIGN); + PipeBarrier(); + + if (aivId_ == 0) { + sendStatusTensor_.SetValue(0, FLAG_VALUE); + DataCacheCleanAndInvalid( + sendStatusTensor_); + } + + LocalTensor tempLocal = tBuf.Get(); + + // 每次调用magic++,用来区分不同轮次 + GlobalTensor magicGt; + magicGt.SetGlobalBuffer((__gm__ int32_t *)(shareAddrs[rankId_ % SERVER_RANK_SIZE] + 
IPC_MAGIC_OFFSET) + + aivId_ * EXP_TOKEN_COUNT_FLAG_CNT); + tempLocal(0) = 1; + // 使用atomic方式实现+1 + AscendC::SetAtomicAdd(); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); // 等待SetValue完成 + DataCopy(magicGt, tempLocal, EXP_TOKEN_COUNT_FLAG_CNT); + AscendC::SetAtomicNone(); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); // 等待SetValue完成 + magicVal_ = magicGt.GetValue(0); + PipeBarrier(); +} + +template +__aicore__ inline void CamMoeDistributeDispatchA2Layered::Input2Win() +{ + uint32_t sendTokenNum = axisBS_ / aivNum_; + uint32_t remainderTokenNum = axisBS_ % aivNum_; + uint32_t startTokenId = sendTokenNum * aivId_; + // 分核,每个Core处理sendTokenNum个Token的遍历 + if (aivId_ < remainderTokenNum) { // 前remainderRankNum个aiv需要多发1个卡的数据 + sendTokenNum += 1; + startTokenId += aivId_; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + + if (sendTokenNum == 0) { + return; + } + int32_t expertId = 0; + uint32_t dstServerId = 0; + uint32_t tokenIndex = 0; + + uint32_t tokenUbSize = tokenStructLen_; + if constexpr (DynamicQuant || StaticQuant) { + tokenUbSize = axisH_ * sizeof(XType); + } + + tpipe_->InitBuffer(tokenServerIdxBuf_, sendTokenNum * serverNum * sizeof(int32_t)); + tpipe_->InitBuffer(serverCountBuf_, serverNum * sizeof(int32_t)); + + tokenServerIdxTensor_ = tokenServerIdxBuf_.Get(); + DataCopyExtParams tokenServerIdxParams = {1U, static_cast(sendTokenNum * serverNum * sizeof(int32_t)), 0U, + 0U, 0U}; + DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; + DataCopyPad(tokenServerIdxTensor_, tokenServerIdxGMTensor_[startTokenId * serverNum], tokenServerIdxParams, + copyPadExtParams); + + // 这几个tensor是相同的地址空间,只是数据类型不一样 + LocalTensor tokenTempTensorU8_ = + tBuf.GetWithOffset(((tokenUbSize) / sizeof(uint8_t)), TBUF_TEMP_OFFSET); + LocalTensor tokenTempTensorU32_ = + tBuf.GetWithOffset(((tokenUbSize) / sizeof(uint32_t)), TBUF_TEMP_OFFSET); + LocalTensor tokenLt = tBuf.GetWithOffset(((tokenUbSize) / sizeof(XType)), TBUF_TEMP_OFFSET); + + GlobalTensor xGMTensorU8_; + xGMTensorU8_.SetGlobalBuffer((__gm__ uint8_t *)expandXGM_); + GlobalTensor expertIdsGMTensorU8_; + expertIdsGMTensorU8_.SetGlobalBuffer((__gm__ uint8_t *)expandIdxGM_); + + GlobalTensor expertIdsGMTensorU32_; + expertIdsGMTensorU32_.SetGlobalBuffer((__gm__ uint32_t *)expandIdxGM_); + + GlobalTensor weightGt; + weightGt.SetGlobalBuffer((__gm__ uint8_t *)weightsGM_); + + DataCopyExtParams tokenCopyParamsQuant{1, static_cast(axisH_ * sizeof(XType)), 0, 0, 0}; + DataCopyExtParams tokenCopyParamsNoQuant{static_cast(1), static_cast(tokenLenInStruct_), 0, 0, + 0}; + DataCopyPadExtParams tokenPadParams; + + DataCopyExtParams expCopyParams{static_cast(1), static_cast(realLenInStruct_), 0, 0, 0}; + DataCopyPadExtParams expPadParams; + + DataCopyExtParams weightCopyParams{static_cast(1), static_cast(realLenInStruct_), 0, 0, 0}; + DataCopyPadExtParams weightPadParams; + + for (int i = 0; i < sendTokenNum; i++) { + if constexpr (DynamicQuant || StaticQuant) { + DataCopyPad(tokenTempTensorU8_, xGMTensorU8_[(startTokenId + i) * axisH_ * sizeof(XType)], + tokenCopyParamsQuant, tokenPadParams); + LocalTensor tokenCastLt = tBuf.GetWithOffset( + ((axisH_ * sizeof(float)) / sizeof(float)), RoundUp(TBUF_TEMP_OFFSET + tokenUbSize, B32_PER_BLOCK)); + QuantProcess(1, tokenLt, tokenCastLt); + } else { + DataCopyPad(tokenTempTensorU8_, xGMTensorU8_[(startTokenId + i) * tokenLenInStruct_], + tokenCopyParamsNoQuant, tokenPadParams); + } + // 拷贝topkIds 可省略 + 
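+        // The staged UB buffer follows the per-token struct computed in Init():
+        //   [0, tokenLenInStruct_)        token payload (raw, or quantized by QuantProcess)
+        //   [expOffsetInStruct_, ...)     top-k expert ids
+        //   [weightOffsetInStruct_, ...)  top-k weights
+        //   [cntOffsetInStruct_, ...)     per-expert token position (tokenIdx), used later in Win2Ipc()
+        //   [scaleOffsetInStruct_, ...)   dynamic-quant scale
+        // The two DataCopyPad calls below fill the expert-id and weight fields before the whole
+        // struct is copied into the RDMA window block of every destination server.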
DataCopyPad(tokenTempTensorU8_[expOffsetInStruct_], expertIdsGMTensorU8_[(startTokenId + i) * realLenInStruct_], + expCopyParams, expPadParams); + + // LocalTensor exd =tokenTempTensorU8_[expOffsetInStruct_].template ReinterpretCast(); + // AscendC::DumpTensor(exd, 475, 32); + // PRINTF("[Input2Win] rank:%d, coreId:%d, weightGt:%d \n", rankId_, aivId_, weightGt[(startTokenId + i) * + // realLenInStruct_].GetValue(0)); 拷贝weight + DataCopyPad(tokenTempTensorU8_[weightOffsetInStruct_], weightGt[(startTokenId + i) * realLenInStruct_], + weightCopyParams, weightPadParams); + + // LocalTensor weigt = tokenTempTensorU8_[weightOffsetInStruct_].template ReinterpretCast(); + // AscendC::DumpTensor(weigt, 482, 32); + SyncFunc(); + for (int j = 0; j < serverNum; j++) { + if (tokenServerIdxTensor_(i * serverNum + j) == -1) { + continue; + } + uint32_t destOffset = + j * SERVER_SIZE_ON_WIN + tokenStructLen_ * tokenServerIdxTensor_(i * serverNum + j) + TOKEN_COUNT_SIZE; + // uint32_t destOffset = + // j * SERVER_SIZE_ON_WIN + tokenStructLen_ * tokenServerIdxTensor_(i * serverNum + j); + DataCopy(sendTokensU8Tensor_[destOffset], tokenTempTensorU8_[0], tokenStructLen_); + + // GlobalTensor sendTokenU32; + // sendTokenU32.SetGlobalBuffer((__gm__ int32_t *)(windowOutGM_)); + // AscendC::DumpTensor(sendTokenU32[(destOffset + expOffsetInStruct_) / 4], 495, 32); + + // GlobalTensor sendTokenU32_wt; + // sendTokenU32_wt.SetGlobalBuffer((__gm__ float *)(windowOutGM_)); + // AscendC::DumpTensor(sendTokenU32_wt[(destOffset + weightOffsetInStruct_) / 4], 499, 32); + + if (j == rankId_ / SERVER_RANK_SIZE) { + DataCopy(readTokensU8Tensor_[destOffset], tokenTempTensorU8_[0], tokenStructLen_); + } + } + SyncFunc(); + } + // AscendC::DumpTensor(sendTokensU8Tensor_, 497, SERVER_SIZE_ON_WIN); +} + +template +__aicore__ inline void CamMoeDistributeDispatchA2Layered::QuantProcess( + uint32_t sendTokenNum, LocalTensor xTokenLt, LocalTensor tokenCastLt) +{ + constexpr uint32_t maxArrUbOffset = 6 * 1024; + constexpr uint32_t maxArrLen = 3; + constexpr uint32_t maxValOffset = 0; + constexpr uint32_t minValOffset = 1; + constexpr uint32_t resValOffset = 2; + constexpr float quantMax = 127.0f; + const half deqScale = static_cast(1.000000e+00f); + float dynamicScale = 0.0; + PipeBarrier(); + LocalTensor workLt = tBuf.GetWithOffset(maxArrUbOffset / sizeof(float), 0); + LocalTensor maxLt = tBuf.GetWithOffset(maxArrLen, maxArrUbOffset); + Cast(tokenCastLt, xTokenLt, RoundMode::CAST_NONE, sendTokenNum * axisH_); + for (int32_t i = 0; i < sendTokenNum; ++i) { + PipeBarrier(); + if constexpr (DynamicQuant) { + ReduceMax(maxLt[maxValOffset], tokenCastLt[i * axisH_], workLt, axisH_, false); + SyncFunc(); + PipeBarrier(); + ReduceMin(maxLt[minValOffset], tokenCastLt[i * axisH_], workLt, axisH_, false); + PipeBarrier(); + Abs(maxLt, maxLt, maxArrLen - 1); + PipeBarrier(); + ReduceMax(maxLt[resValOffset], maxLt, workLt, maxArrLen - 1, false); + + SyncFunc(); + float maxVal = maxLt(resValOffset); + dynamicScale = float(quantMax) / float(maxVal); + SyncFunc(); + Muls(tokenCastLt[i * axisH_], tokenCastLt[i * axisH_], dynamicScale, axisH_); + PipeBarrier(); + } + + LocalTensor halfLocalTemp = tokenCastLt[i * axisH_].template ReinterpretCast(); + LocalTensor int32LocalTemp = tokenCastLt[i * axisH_].template ReinterpretCast(); + Cast(int32LocalTemp, tokenCastLt[i * axisH_], RoundMode::CAST_RINT, axisH_); + PipeBarrier(); + SetDeqScale(deqScale); + PipeBarrier(); + + Cast(halfLocalTemp, int32LocalTemp, RoundMode::CAST_ROUND, axisH_); + + 
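+        // When DynamicQuant is enabled, the per-token quantization above/below is:
+        //   dynamicScale = 127.0f / max(|x|)   (ReduceMax/ReduceMin + Abs over the token)
+        //   x_int8       = round(x_f32 * dynamicScale), narrowed via f32 -> int32 -> half -> int8
+        // The reciprocal 1 / dynamicScale is stored in the struct's scale slot so the receiver
+        // can dequantize (int8 -> float32) after Ipc2Out.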
PipeBarrier(); + LocalTensor xOutTensor; + LocalTensor tokenUnitLt; + tokenUnitLt = xTokenLt.template ReinterpretCast(); + xOutTensor = tokenUnitLt[i * tokenStructLen_].template ReinterpretCast(); + Cast(xOutTensor, halfLocalTemp, RoundMode::CAST_TRUNC, axisH_); + + LocalTensor scaleTensor = + tokenUnitLt[i * tokenStructLen_ + scaleOffsetInStruct_].template ReinterpretCast(); + scaleTensor.SetValue(0, float(1.0) / dynamicScale); // int8->float32 + } +} + +// template +// __aicore__ inline void CamMoeDistributeDispatchA2Layered::CopyTokenToWinOut( +// uint32_t localTokenIdx, uint32_t globalTokenIdx, uint32_t dstServerId) +// { +// uint32_t curServerId = rankId_ / SERVER_RANK_SIZE; +// uint32_t toServerCntSum = 0; +// SyncFunc(); + +// for (uint32_t tokenIdx = 0; tokenIdx < globalTokenIdx; tokenIdx++) { +// uint32_t tensorOffset = tokenIdx * serverNum + dstServerId; +// toServerCntSum += expertToServerCntTensor_(tensorOffset); +// } + +// LocalTensor tokenTempTensorU8_ = tBuf.GetWithOffset((DISPATCH_TOKEN_UB_SIZE), +// TBUF_TEMP_OFFSET); SyncFunc(); uint32_t destOffset = dstServerId * SERVER_SIZE_ON_WIN +// + tokenStructLen_ * toServerCntSum + TOKEN_COUNT_SIZE; DataCopy(sendTokensU8Tensor_[destOffset], +// tokenTempTensorU8_[localTokenIdx * tokenStructLen_], tokenStructLen_); +// } + +template +__aicore__ inline void CamMoeDistributeDispatchA2Layered::WriteRdmaCntInfo() +{ + uint32_t destServerNum = serverNum / aivNum_; // 每个AIV要处理的server数 + uint32_t remaServerNum = serverNum % aivNum_; + uint32_t startServerId = destServerNum * aivId_; + if (aivId_ < remaServerNum) { // 前remainderRankNum个aiv需要多发1个卡的数据 + destServerNum += 1; + startServerId += aivId_; + } else { + startServerId += remaServerNum; + } + if (destServerNum == 0) { + return; + } + + tpipe_->InitBuffer(serverCountBuf_, serverNum * sizeof(int32_t)); + serverCountTensor_ = serverCountBuf_.Get(); + DataCopyExtParams serverCountParams = {1U, static_cast(serverNum * sizeof(int32_t)), 0U, 0U, 0U}; + DataCopyPadExtParams copyPadExtParams{false, 0U, 0U, 0U}; + DataCopyPad(serverCountTensor_, tokenServerCntGMTensor_[0], serverCountParams, copyPadExtParams); + SyncFunc(); + for (uint32_t dstServerId = startServerId; dstServerId < startServerId + destServerNum; ++dstServerId) { + uint32_t dstServerCnt = serverCountTensor_(dstServerId); + expertToServerIdxTensor_(dstServerId) = dstServerCnt; + LocalTensor writeCntLt = tBuf.GetWithOffset(EXP_TOKEN_COUNT_FLAG_CNT, 0); + writeCntLt.SetValue(0, dstServerCnt); + uint32_t destOffset = (dstServerId * SERVER_SIZE_ON_WIN) / sizeof(uint32_t); + + SyncFunc(); + // DataCopy(sendTokensU32Tensor_[destOffset], writeCntLt, EXP_TOKEN_COUNT_FLAG_CNT); + } +} + +// 构建发往其他server的所有data报文 +template +__aicore__ inline void +CamMoeDistributeDispatchA2Layered::ConstructDataAndFlagBatchWriteInfo() +{ + // 计算当前core要处理的server + uint32_t batchWriteItemNum = serverNum / aivNum_; // 一个aiv负责的server数量 + uint32_t remainderItemNum = serverNum % aivNum_; // 多出来的server没人处理 + uint32_t startServerId = batchWriteItemNum * aivId_; // 当前aiv负责[startServerId,endServerId)个server + uint32_t curServerId = rankId_ / SERVER_RANK_SIZE; // 当前serverId + + if (aivId_ < remainderItemNum) { + startServerId += aivId_; // aiv0:1*0+0=0,aiv1:1*1+1=2,aiv2:1*2+2=4,... 
aiv23:1*23+23=46, + batchWriteItemNum += 1; // 前remainderItemNum个aiv需要多处理1个server的数据 + } else { + startServerId += remainderItemNum; // aiv24:1*24+24=48, aiv25:1*25+24=49 + } + uint32_t endServerId = startServerId + batchWriteItemNum; + if (batchWriteItemNum == 0) { + return; + } + // 当前aiv负责 [startServerId,endServerId) 个 server + for (uint32_t dstserverInd = startServerId; dstserverInd < endServerId; ++dstserverInd) { + uint32_t sendIdx = dstserverInd - startServerId; + uint32_t dstRankId = rankId_ % SERVER_RANK_SIZE + dstserverInd * SERVER_RANK_SIZE; // 目标Rank + PipeBarrier(); + uint64_t dstDataRdmaAddr = (uint64_t)(hccl_.GetWindowsInAddr(dstRankId) + NOTIFY_OFFSET + + halfWinSize_ * bufferId_ + curServerId * SERVER_SIZE_ON_WIN); + // src卡GetWindowsInAddr地址, 要发给serverIndex,即是本端的rdma地址 + uint64_t srcDataRdmaAddr = + (uint64_t)(hccl_.GetWindowsOutAddr(rankId_) + halfWinSize_ * bufferId_ + dstserverInd * SERVER_SIZE_ON_WIN); + + // for (int j = 0; j < 16; ++j) { + // GlobalTensor sendTokenU32; + // sendTokenU32.SetGlobalBuffer((__gm__ int32_t *)(hccl_.GetWindowsOutAddr(rankId_) + halfWinSize_ * + // bufferId_ + dstserverInd * SERVER_SIZE_ON_WIN + TOKEN_COUNT_SIZE)); + // AscendC::DumpTensor(sendTokenU32[(expOffsetInStruct_) / 4], 658, 32); + + // GlobalTensor sendTokenU32_wt; + // sendTokenU32_wt.SetGlobalBuffer((__gm__ float *)(hccl_.GetWindowsOutAddr(rankId_) + halfWinSize_ * + // bufferId_ + dstserverInd * SERVER_SIZE_ON_WIN + TOKEN_COUNT_SIZE)); + // AscendC::DumpTensor(sendTokenU32_wt[(weightOffsetInStruct_) / 4], 662, 32); + // } + + // 去往该Server的传输的数据量 + uint32_t validTokenCount = expertToServerIdxTensor_(dstserverInd); + PRINTF("[BatchWriteInfo] rank:%d, aivId_:%d, dstServer:%d, tokenCnt:%d\n", rankId_, aivId_, dstserverInd, + validTokenCount); + uint32_t validDataLength = TOKEN_COUNT_SIZE + validTokenCount * tokenStructLen_; + // uint32_t validDataLength = validTokenCount * tokenStructLen_; + uint64_t winInAddr = (uint64_t)(hccl_.GetWindowsInAddr(rankId_) + NOTIFY_OFFSET); + uint64_t winOutAddr = (uint64_t)(hccl_.GetWindowsOutAddr(rankId_)); + PipeBarrier(); + batchWriteU64Tensor_(0) = srcDataRdmaAddr; // 源地址 + batchWriteU64Tensor_(1) = dstDataRdmaAddr; // 目的地址 + batchWriteU64Tensor_(2) = validDataLength; // 数据长度 + batchWriteU32Tensor_(6) = HcclDataType::HCCL_DATA_TYPE_INT8; + batchWriteU32Tensor_(7) = dstRankId; // dst卡 + + uint64_t dstFlagRdmaAddr = (uint64_t)(hccl_.GetWindowsInAddr(dstRankId) + NOTIFY_OFFSET + + halfWinSize_ * bufferId_ + WIN_SIZE + curServerId * STATE_OFFSET); + + // src卡,即是本端的rdma地址 + uint64_t srcFlagRdmaAddr = (uint64_t)(sendStatusTensor_.GetPhyAddr()); + uint32_t flagLen = TOKEN_COUNT_SIZE; + PipeBarrier(); + batchWriteU64Tensor_(4) = srcFlagRdmaAddr; // 源地址 + batchWriteU64Tensor_(5) = dstFlagRdmaAddr; // 目的地址 + batchWriteU64Tensor_(6) = flagLen; // 数据长度 + batchWriteU32Tensor_(14) = HcclDataType::HCCL_DATA_TYPE_INT8; + batchWriteU32Tensor_(15) = dstRankId; // dst卡 + + SyncFunc(); + uint32_t dstServerOffset = dstserverInd; + uint32_t sendInfoCount = B64_PER_BLOCK * PER_MSG_RDMA_SEND_TIME; + DataCopy(dataBatchWriteInfoTensor_[dstServerOffset * sendInfoCount], batchWriteU64Tensor_, sendInfoCount); + } +} + +// 机间同平面RDMA通信 +template +__aicore__ inline void CamMoeDistributeDispatchA2Layered::DispatchBetweenServer() +{ + ConstructDataAndFlagBatchWriteInfo(); + PipeBarrier(); + SyncAll(); + if ASCEND_IS_AIV { + if (aivId_ == 0) { + HcclHandle batchWriteResultData = hccl_.BatchWrite((GM_ADDR)(dataBatchWriteInfoTensor_.GetPhyAddr()), + serverNum * 
PER_MSG_RDMA_SEND_TIME); + bufferChosenGlobal_(0) = bufferId_ ^ 1; + DataCacheCleanAndInvalid( + bufferChosenGlobal_); + } + } +} + +template +__aicore__ inline uint32_t +CamMoeDistributeDispatchA2Layered::GetExpRank(uint32_t expertId) +{ + return expertId / localMoeExpertNum_; +} + +template +__aicore__ inline bool +CamMoeDistributeDispatchA2Layered::IsInSameServer(uint32_t targetRankId) +{ + return targetRankId / SERVER_RANK_SIZE == rankId_ / SERVER_RANK_SIZE; +} + +template +__aicore__ inline int64_t +CamMoeDistributeDispatchA2Layered::MergeMagicWithValue(int32_t magic, int32_t value) +{ + return (static_cast(magic) << 32) | static_cast(value); +} + +template +__aicore__ inline void CamMoeDistributeDispatchA2Layered::SetIpcFlag(int32_t flagVal) +{ + if (aivId_ >= SERVER_RANK_SIZE) { + return; + } + uint32_t destRankIdx = aivId_; + uint32_t localRankId = rankId_ % SERVER_RANK_SIZE; + GlobalTensor globalSet; + globalSet.SetGlobalBuffer((__gm__ int64_t *)(shareAddrs[destRankIdx]) + localRankId * B64_PER_BLOCK); + LocalTensor localSet = tBuf.GetWithOffset(B64_PER_BLOCK, 0); + int64_t setVal = MergeMagicWithValue(magicVal_, flagVal); + localSet.SetValue(0, setVal); + SyncFunc(); + DataCopy(globalSet, localSet, B64_PER_BLOCK); +} + +template +__aicore__ inline void CamMoeDistributeDispatchA2Layered::WaitIpcFlag(int32_t flagVal) +{ + int64_t waitVal = MergeMagicWithValue(magicVal_, flagVal); + if (aivId_ >= SERVER_RANK_SIZE) { + return; + } + LocalTensor localWait = tBuf.GetWithOffset(B64_PER_BLOCK, 0); + bool isSync = true; + uint32_t destRankIdx = aivId_; + uint32_t localRankId = rankId_ % SERVER_RANK_SIZE; + GlobalTensor flagIpcGt; + flagIpcGt.SetGlobalBuffer((__gm__ int64_t *)(shareAddrs[localRankId]) + destRankIdx * B64_PER_BLOCK); + PipeBarrier(); + do { + DataCopy(localWait, flagIpcGt, B64_PER_BLOCK); + SyncFunc(); + // 当有core未达到checkValue的阶段时,继续等待 + int64_t tempVal = localWait.GetValue(0); + if (tempVal >= waitVal) { + break; + } + } while (isSync); +} + +template +__aicore__ inline void +CamMoeDistributeDispatchA2Layered::SetTokenCnt(GlobalTensor globalSet) +{ + AscendC::SetAtomicAdd(); + LocalTensor localSet = tBuf.GetWithOffset(EXP_TOKEN_COUNT_FLAG_CNT, 0); + localSet(0) = 1; // AtomicAdd每次+1 + SyncFunc(); + DataCopy(globalSet, localSet, EXP_TOKEN_COUNT_FLAG_CNT); + SyncFunc(); + AscendC::SetAtomicNone(); +} + +template +__aicore__ inline void CamMoeDistributeDispatchA2Layered::WaitWindow() +{ + // 前ServerNum个卡进行等待,等待本server的也保留 + if (aivId_ >= serverNum || aivId_ == (rankId_ / SERVER_RANK_SIZE)) { + return; // 不等待本server + } + uint32_t waitFlagIdx = aivId_; + PipeBarrier(); + LocalTensor statusTensor = statusBuf_.Get(); + while (true) { + DataCopy(statusTensor, readStatusTensor_[(waitFlagIdx)*STATE_OFFSET / sizeof(int32_t)], FLAG_U32_CNT); + SyncFunc(); + int32_t sumOfFlag = statusTensor.GetValue(0); + if (sumOfFlag == FLAG_VALUE) { + break; + } + } +} + +template +__aicore__ inline void CamMoeDistributeDispatchA2Layered::Win2Ipc() +{ + uint32_t coresPerServer = (aivNum_ - serverNum) / serverNum; // 48/2 = 24 + if (aivId_ >= coresPerServer * serverNum) { + return; + } + // 计算本core需要处理的ServerId + uint32_t formServerId = aivId_ / coresPerServer; // 前24处理0, 后24处理1 + + // 获取tokenCnt,计算本卡收到对端server多少Token,用于后续分核计算 + __gm__ uint8_t *tokenCntGlobalAddr; + if (formServerId == rankId_ / SERVER_RANK_SIZE) { + tokenCntGlobalAddr = (__gm__ uint8_t *)(windowOutGM_) + formServerId * SERVER_SIZE_ON_WIN; + } else { + tokenCntGlobalAddr = (__gm__ uint8_t *)(windowInGM_) + formServerId * 
SERVER_SIZE_ON_WIN; + } + GlobalTensor tokenCntGlobalTensor; + tokenCntGlobalTensor.SetGlobalBuffer((__gm__ uint32_t *)(tokenCntGlobalAddr)); + LocalTensor localWait = tBuf.GetWithOffset(EXP_TOKEN_COUNT_FLAG_CNT, 0); + + DataCopy(localWait, tokenCntGlobalTensor, EXP_TOKEN_COUNT_FLAG_CNT); + SyncFunc(); + uint32_t tokenCnt = localWait.GetValue(0); + + GlobalTensor targetTokenIpcGt; // 对端IPC的TokenTensor,写数据用 + + uint32_t WinInTokenOffset = formServerId * SERVER_SIZE_ON_WIN + TOKEN_COUNT_SIZE; + uint32_t localAivId = aivId_ % coresPerServer; // 0,1,2,3...19 + // 平均每个核处理多少token + uint32_t tokenCntPerAiv = tokenCnt / coresPerServer; // 16/20 + // 平分后剩下多少token + uint32_t tokenCntRemain = tokenCnt % coresPerServer; // 16%20 + // 前面的核共分到了多少剩余 + uint32_t tokenCntPreRemain = (localAivId < tokenCntRemain) ? localAivId : tokenCntRemain; // 小于16为 + // 当前核分到多少token + uint32_t tokenCntCurAiv = (localAivId < tokenCntRemain) ? (tokenCntPerAiv + 1) : tokenCntPerAiv; + + LocalTensor localUB = + tBuf.GetWithOffset(DISPATCH_TOKEN_UB_SIZE / sizeof(uint8_t), TBUF_TEMP_OFFSET); + uint32_t tokenCntInUB = DISPATCH_TOKEN_UB_SIZE / tokenStructLen_; + // ceil div + uint32_t batchCnt = (tokenCntCurAiv + tokenCntInUB - 1) / tokenCntInUB; + for (uint32_t batchIdx = 0; batchIdx < batchCnt; ++batchIdx) { + uint32_t tokenCntInBatch = tokenCntInUB; + if (batchIdx == batchCnt - 1) { + tokenCntInBatch = tokenCntCurAiv - (batchCnt - 1) * tokenCntInUB; + } + // 计算当前Core处理的Token偏移 + uint32_t tokenStruceIdx = localAivId * tokenCntPerAiv + tokenCntPreRemain + batchIdx * tokenCntInUB; + // 等待GM->UB + if (formServerId == rankId_ / SERVER_RANK_SIZE) { + SyncFunc(); + DataCopy(localUB, sendTokensU8Tensor_[WinInTokenOffset + tokenStruceIdx * tokenStructLen_], + tokenCntInBatch * tokenStructLen_); + } else { + SyncFunc(); + DataCopy(localUB, readTokensU8Tensor_[WinInTokenOffset + tokenStruceIdx * tokenStructLen_], + tokenCntInBatch * tokenStructLen_); + } + SyncFunc(); + + for (uint32_t tokenIdx = 0; tokenIdx < tokenCntInBatch; ++tokenIdx) { + // 逐个处理Token to Ipc + uint32_t expPos = tokenIdx * tokenStructLen_ + expOffsetInStruct_; + LocalTensor expInfoTensor = localUB[expPos].ReinterpretCast(); + // 当前Token的ExpIds信息 + uint32_t tokenCntPos = tokenIdx * tokenStructLen_ + cntOffsetInStruct_; + LocalTensor cntInfoTensor = localUB[tokenCntPos].ReinterpretCast(); + // 当前Token的Cnt信息 + for (uint32_t expIdx = 0; expIdx < axisK_; ++expIdx) { + uint32_t targetexpertId = expInfoTensor[expIdx].GetValue(0); + uint32_t targetRankId = GetExpRank(targetexpertId); + if (!IsInSameServer(targetRankId)) { + continue; + } + uint32_t tokenPosInBlock = cntInfoTensor(expIdx); + PipeBarrier(); + // 在IPC的当前Block中,前面还有tokenPosInBlock个Token + uint32_t targetExpOffset = (targetexpertId % localMoeExpertNum_) * worldSize_ * RANK_SIZE_ON_IPC; + // 第几个Exp段 + uint32_t targetServerOffset = formServerId * SERVER_RANK_SIZE * RANK_SIZE_ON_IPC; + // 第几个Server段 + uint32_t targetRankOffset = (rankId_ % SERVER_RANK_SIZE) * RANK_SIZE_ON_IPC; + // 第几个Rank段 + uint32_t targetTokenOffset = tokenPosInBlock * tokenStructLen_; // 第几个Token位 + uint32_t targetOffset = targetExpOffset + targetServerOffset + targetRankOffset + targetTokenOffset; + + targetTokenIpcGt.SetGlobalBuffer( + (__gm__ uint8_t *)(shareAddrs[targetRankId % SERVER_RANK_SIZE] + IPC_DATA_OFFSET + targetOffset)); + PipeBarrier(); + DataCopy(targetTokenIpcGt, localUB[tokenIdx * tokenStructLen_], tokenStructLen_); + // 对应token个数加1 + GlobalTensor targetCntIpcGt; // 对端IPC的CntTensor,统计对端收到的次数 + 
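+                // Per-(expert, source-rank) token counters live at IPC_TOKEN_CNT_OFFSET in the
+                // target rank's shared memory: one 32B slot (EXP_TOKEN_COUNT_FLAG_CNT int32s) at
+                // index (expertId % localMoeExpertNum_) * worldSize_ + formServerId * SERVER_RANK_SIZE
+                // + (rankId_ % SERVER_RANK_SIZE). SetTokenCnt() below bumps the slot with an atomic
+                // add so AIVs writing concurrently from this rank stay consistent.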
targetCntIpcGt.SetGlobalBuffer((__gm__ int32_t *)(shareAddrs[targetRankId % SERVER_RANK_SIZE] + + IPC_TOKEN_CNT_OFFSET)); // 前面记录有几个token + uint32_t setTokenCntOffset = (targetexpertId % localMoeExpertNum_) * worldSize_ + + formServerId * SERVER_RANK_SIZE + (rankId_ % SERVER_RANK_SIZE); + SetTokenCnt(targetCntIpcGt[EXP_TOKEN_COUNT_FLAG_CNT * setTokenCntOffset]); + } + } + } +} + +// 每个专家从不同的server块取数据 +template +__aicore__ inline void CamMoeDistributeDispatchA2Layered::Ipc2Out() +{ + uint32_t curRankExpertStart = rankId_ * localMoeExpertNum_; // 9*8=72 + uint32_t curRankExpertEnd = curRankExpertStart + localMoeExpertNum_ - 1; // 72+8-1=79 + + PRINTF("[Ipc2Out] blockIdx %d\n", aivId_); + + // for (int i =0 ; i < 1; ++i) { + // GlobalTensor srcIpcU; + // srcIpcU.SetGlobalBuffer((__gm__ uint8_t *)(shareAddrWins[rankId_]) + i * SERVER_SIZE_ON_WIN); + + // for (int j = 0; j < 4096; ++j) { + // GlobalTensor sendTokenU32; + // sendTokenU32.SetGlobalBuffer((__gm__ int32_t *)((shareAddrWins[rankId_]) + i * SERVER_SIZE_ON_WIN + j * + // tokenStructLen_ + TOKEN_COUNT_SIZE)); AscendC::DumpTensor(sendTokenU32[(expOffsetInStruct_) / 4], 920, + // 32); + + // GlobalTensor sendTokenU32_wt; + // sendTokenU32_wt.SetGlobalBuffer((__gm__ float *)((shareAddrWins[rankId_]) + i * SERVER_SIZE_ON_WIN + j * + // tokenStructLen_ + TOKEN_COUNT_SIZE)); AscendC::DumpTensor(sendTokenU32_wt[(weightOffsetInStruct_) / 4], + // 924, 32); + // } + // } + + LocalTensor weightLtV = weightBuf_.Get(); + for (uint32_t srcRank = 0; srcRank < worldSize_; ++srcRank) { + uint32_t localRankIdx = srcRank % SERVER_RANK_SIZE; // 20%8=4 server上的序号4rank,即第5个 + uint32_t curServerIdx = rankId_ / SERVER_RANK_SIZE; // 9/8=1 server1,即第2个 + uint32_t targetRankId = + localRankIdx + curServerIdx * SERVER_RANK_SIZE; // 4+1*8=12 当前server上的rank,全局rankid=12 + uint32_t tarServerBlockIdx = srcRank / SERVER_RANK_SIZE; // 20/8=2 目标rank上的block序号2,即第3块 + + GlobalTensor srcIpcGt; // TODO: 取地址可能有问题,需要为 targetRankId 的 tarServerBlockIdx 的地址 + srcIpcGt.SetGlobalBuffer((__gm__ uint8_t *)(shareAddrWins[localRankIdx]) + + tarServerBlockIdx * SERVER_SIZE_ON_WIN + TOKEN_COUNT_SIZE); + // srcIpcGt.SetGlobalBuffer((__gm__ uint8_t *)(shareAddrWins[localRankIdx]) + + // tarServerBlockIdx * SERVER_SIZE_ON_WIN); + + for (uint32_t recvExpId = curRankExpertStart; recvExpId <= curRankExpertEnd; ++recvExpId) { + int recvTokenCnt = epRankTokenCntGMTensor_.GetValue(recvExpId * worldSize_ + + srcRank); // 专家recvExpId从srcRank收的token个数 + PRINTF("[Ipc2Out] blockIdx:%d, recvTokenCnt:%d\n", aivId_, recvTokenCnt); + + uint32_t beginIndex = 0; + uint32_t endIndex = 0; + // 分核处理token数量 + + uint32_t tokenCntPerAiv = recvTokenCnt / aivNum_; + uint32_t remainTokenNum = recvTokenCnt % aivNum_; + beginIndex = tokenCntPerAiv * aivId_; + if (aivId_ < remainTokenNum) { + tokenCntPerAiv++; + beginIndex += aivId_; + } else { + beginIndex += remainTokenNum; + } + endIndex = beginIndex + tokenCntPerAiv; + if (beginIndex >= recvTokenCnt) { + continue; + } + LocalTensor localUB = tBuf.GetWithOffset( + (DISPATCH_TOKEN_UB_SIZE - TBUF_TEMP_OFFSET) / sizeof(uint8_t), TBUF_TEMP_OFFSET); + + DataCopyExtParams copyParams{1, static_cast(tokenStructLen_), 0, 0, 0}; + DataCopyPadExtParams padParams; + DataCopyPadExtParams tokenExtParams{false, 0U, 0U, 0U}; + DataCopyExtParams weightParams{1, static_cast(sizeof(float)), 0, 0, 0}; + DataCopyPadExtParams weightExtParams{false, 0U, 0U, 0U}; + DataCopyExtParams scalesParams{1, static_cast(sizeof(float)), 0, 0, 0}; + DataCopyPadExtParams scalesExtParams{false, 0U, 
+
+            for (int i = beginIndex; i < endIndex; ++i) {
+                // assuming the offset tables have shape [expertNum, rank, maxBs]: read the src and dst
+                // offsets of the i-th token that expert recvExpId receives from srcRank
+                int32_t srcOffset =
+                    srcOffsetRankTokenIdxGMTensor_.GetValue(recvExpId * worldSize_ * BS_UPPER + srcRank * BS_UPPER + i);
+                int32_t dstOffset =
+                    dstOffsetRankTokenIdxGMTensor_.GetValue(recvExpId * worldSize_ * BS_UPPER + srcRank * BS_UPPER + i);
+
+                uint32_t tokenOffset =
+                    (tokenStructLen_ * srcOffset);  // the record holds the token plus trailing info: expIds, weights, tokenIdx, scales
+
+                DataCopyPad(localUB, srcIpcGt[tokenOffset], copyParams, padParams);  // winIn --> local
+                SyncFunc();
+                LocalTensor tokenLt = localUB.ReinterpretCast();
+                DataCopyExtParams tokenParams{1, static_cast(tokenLenInStruct_), 0, 0, 0};
+                DataCopyPad(expandXOutGMTensor_[dstOffset * axisH_], tokenLt, tokenParams);  // local --> out
+
+                LocalTensor expLt = localUB[expOffsetInStruct_].ReinterpretCast();
+                SyncFunc();
+                int index = 100;
+                for (int j = 0; j < axisK_; j++) {
+                    if (rankId_ == 0) {
+                        PRINTF("[Ipc2Out] rank:%d, aivId_:%d, topk:%d\n", rankId_, aivId_, expLt.GetValue(j));
+                    }
+
+                    if (expLt.GetValue(j) == recvExpId) {
+                        index = j;
+                    }
+                }
+                // weight to output
+                LocalTensor weightLt = localUB[weightOffsetInStruct_].ReinterpretCast();
+                float weightVal = weightLt.GetValue(index);
+                if (index == 100) {
+                    AscendC::DumpTensor(weightLt, 1016, 32);
+                }
+                float target = (float)1.0;
+                if (weightVal != target) {
+                    PRINTF(
+                        "[Ipc2Out] rank:%d, aivId_:%d, curRankExpertStart:%d, curRankExpertEnd:%d, localRankIdx:%d, "
+                        "curServerIdx:%d, targetRankId:%d, tarServerBlockIdx:%d, recvTokenCnt:%d, i:%d, recvExpId:%d, "
+                        "srcRank:%d, srcOffset:%d, dstOffset:%d, tokenOffset:%d, weightVal:%f, index:%d, tokenLt:%d\n",
+                        rankId_, aivId_, curRankExpertStart, curRankExpertEnd, localRankIdx, curServerIdx, targetRankId,
+                        tarServerBlockIdx, recvTokenCnt, i, recvExpId, srcRank, srcOffset, dstOffset, tokenOffset,
+                        weightVal, index, tokenLt(0));
+                }
+
+                SyncFunc();
+                weightLtV(0) = weightLt.GetValue(index);
+                pipe_barrier(PIPE_ALL);
+                DataCopyPad(weightsOutGt[dstOffset], weightLtV, weightParams);
+                pipe_barrier(PIPE_ALL);
+
+                // dynamic scales to output
+                if constexpr (DynamicQuant) {
+                    LocalTensor quantTempUB = localUB[scaleOffsetInStruct_].ReinterpretCast();
+                    DataCopyPad(dynamicScalesOutGMTensor_[dstOffset], quantTempUB, scalesParams);
+                }
+                SyncFunc();
+            }
+        }
+    }
+}
+
+template
+__aicore__ inline void CamMoeDistributeDispatchA2Layered::CleanUp()  // clear the status flags
+{
+    uint32_t cleanBuffSize = worldSize_ * localMoeExpertNum_ * TOKEN_COUNT_SIZE;
+    if (cleanBuffSize < STATE_OFFSET * serverNum) {
+        cleanBuffSize = STATE_OFFSET * serverNum;
+    }
+    LocalTensor cleanTempLt_ = tBuf.GetWithOffset(cleanBuffSize / sizeof(int32_t), TBUF_TEMP_OFFSET);
+    GlobalTensor flagIpcGt;
+    Duplicate(cleanTempLt_, 0, cleanBuffSize / sizeof(int32_t));
+    PipeBarrier();
+    flagIpcGt.SetGlobalBuffer((__gm__ int32_t *)(shareAddrs[rankId_ % SERVER_RANK_SIZE]));
+    PipeBarrier();
+    DataCopy(readStatusTensor_, cleanTempLt_, cleanBuffSize / sizeof(int32_t));
+    PipeBarrier();
+    DataCopy(flagIpcGt[IPC_TOKEN_CNT_OFFSET / sizeof(int32_t)], cleanTempLt_, cleanBuffSize / sizeof(int32_t));
+}
+
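+// A reading of the per-call pipeline driven by Process() below (stage roles inferred from the stage
+// names and the code in this file; the IPC flag steps hand-shake ranks through the shared IPC area):
+//   Input2Win              pack this rank's tokens into the HCCL window
+//   WriteRdmaCntInfo       publish per-server token counts
+//   DispatchBetweenServer  send window data to the peer servers
+//   WaitWindow             wait until the remote window data has landed
+//   Win2Ipc                scatter window tokens into the intra-server IPC buffer (currently disabled here)
+//   SetIpcFlag/WaitIpcFlag(IPC_FLAG_STEP_1)   intra-server handshake
+//   Ipc2Out                gather tokens from the IPC buffer into the kernel outputs
+//   CleanUp                zero the status/counter areas (core 0 only)
+//   SetIpcFlag/WaitIpcFlag(IPC_FLAG_STEP_2)   final handshake before hccl_.Finalize()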
+template
+__aicore__ inline void CamMoeDistributeDispatchA2Layered::Process()
+{
+    if ASCEND_IS_AIV {  // everything runs on AIV cores
+        PRINTF("[A2layer Process blockIdx %d]\n", aivId_);
+        Input2Win();
+        PRINTF("[A2layer Input2Win blockIdx %d]\n", aivId_);
+        PipeBarrier();
+        SyncAll();
+        PRINTF("[A2layer b4WriteRdmaCntInfo blockIdx %d]\n", aivId_);
+        WriteRdmaCntInfo();
+        PRINTF("[A2layer b4DispatchBetweenServer blockIdx %d]\n", aivId_);
+        DispatchBetweenServer();
+        PRINTF("[A2layer b4WaitWindow blockIdx %d]\n", aivId_);
+        WaitWindow();
+        PRINTF("[A2layer AfterWaitWindow blockIdx %d]\n", aivId_);
+        PipeBarrier();
+        SyncAll();
+        PRINTF("[A2layer Win2Ipc blockIdx %d]\n", aivId_);
+        // the last serverNum cores do not take part in Win2Ipc; they only compute the reduce info
+        if (aivId_ < aivNum_ - serverNum) {
+            // Win2Ipc();
+        }
+        PipeBarrier();
+        SyncAll();
+
+        PRINTF("[A2layer b4SetIpcFlag blockIdx %d]\n", aivId_);
+        SetIpcFlag(IPC_FLAG_STEP_1);
+        PRINTF("[A2layer b4WaitIpcFlag blockIdx %d]\n", aivId_);
+        WaitIpcFlag(IPC_FLAG_STEP_1);
+        PRINTF("[A2layer AfterWaitIpcFlag blockIdx %d]\n", aivId_);
+        PipeBarrier();
+        SyncAll();
+        PRINTF("[A2layer b4Ipc2Out blockIdx %d]\n", aivId_);
+        Ipc2Out();
+        PRINTF("[A2layer AfterIpc2Out blockIdx %d]\n", aivId_);
+        PipeBarrier();
+        SyncAll();
+
+        PRINTF("[A2layer b4CleanUp blockIdx %d]\n", aivId_);
+        if (aivId_ == 0) {
+            CleanUp();
+        }
+        PRINTF("[A2layer AfterCleanUp blockIdx %d]\n", aivId_);
+        PipeBarrier();
+        SetIpcFlag(IPC_FLAG_STEP_2);  // TODO: clarify why this synchronization is needed
+        WaitIpcFlag(IPC_FLAG_STEP_2);
+        PipeBarrier();
+        SyncAll();
+
+        hccl_.Finalize();
+    }
+}
+} // namespace MoeDistributeDispatchA2Impl
+#endif // MOE_DISTRIBUTE_DISPATCH_A2_LAYERED_H
diff --git a/csrc/deepep/ops2/op_kernel/cam_moe_distribute_dispatch_tiling.h b/csrc/deepep/ops2/op_kernel/cam_moe_distribute_dispatch_tiling.h
new file mode 100644
index 00000000..4b3ea1ca
--- /dev/null
+++ b/csrc/deepep/ops2/op_kernel/cam_moe_distribute_dispatch_tiling.h
@@ -0,0 +1,73 @@
+/**
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ */
+
+#ifndef ASCENDC_CAM_H_COMM_MOE_DISTRIBUTE_DISPATCH_TILING_H
+#define ASCENDC_CAM_H_COMM_MOE_DISTRIBUTE_DISPATCH_TILING_H
+
+#include
+#include "kernel_tiling/kernel_tiling.h"
+
+namespace Cam {
+struct CamMoeDistributeDispatchA2Info {
+    uint32_t epWorldSize;          // epWorldSize
+    uint32_t tpWorldSize;          // tpWorldSize
+    uint32_t epRankId;             // epRankId
+    uint32_t tpRankId;             // tpRankId
+    uint32_t expertSharedType;     // expert type
+    uint32_t sharedExpertRankNum;  // shared expert number
+    uint32_t moeExpertNum;         // moe expert number
+    uint32_t quantMode;            // quant mode
+    uint32_t globalBs;             // globalBs = BS * worldSize
+    uint32_t bs;                   // bs
+    uint32_t k;                    // k
+    uint32_t h;                    // h
+    uint32_t aivNum;               // aivNum
+    bool isQuant;                  // whether quantization is enabled
+    bool reserved1;                // reserved
+    bool reserved2;                // reserved
+    bool reserved3;                // reserved
+    uint64_t totalUbSize;          // total UB size
+    uint32_t hcclBufferSize;       // HCCL windows, unit:B
+    uint32_t expertTokenNumsType;  // expert token nums type, support 0: cumsum mode, 1: count mode
+};
+
+struct CamMoeDistributeDispatchA2TilingData {
+    Mc2InitTiling mc2InitTiling;
+    Mc2CcTiling mc2CcTiling;
+    CamMoeDistributeDispatchA2Info moeDistributeDispatchInfo;
+};
+
+struct CamMoeDistributeDispatchInfo {
+    uint32_t epWorldSize;          // epWorldSize
+    uint32_t tpWorldSize;          // tpWorldSize
+    uint32_t epRankId;             // epRankId
+    uint32_t tpRankId;             // tpRankId
+    uint32_t expertShardType;      // expert type
+    uint32_t sharedExpertRankNum;  // shared expert number
+    uint32_t moeExpertNum;         // moe expert number
+    uint32_t quantMode;            // quant mode
+    uint32_t globalBs;             // globalBs = BS * worldSize
+    uint32_t bs;                   // bs
+    uint32_t k;                    // k
+    uint32_t h;                    // h
+    uint32_t aivNum;               // aivNum
+    bool isQuant;                  // whether quantization is enabled
+    bool reserved1;                // reserved
+    bool reserved2;                // reserved
+    bool reserved3;                // reserved
+    uint64_t totalUbSize;          // total UB size
+    uint64_t totalWinSize;
+    uint32_t expertTokenNumsType;  // expert token nums type, support 0: cumsum mode, 1: count mode
+    uint64_t magic;
+};
+
+struct CamMoeDistributeDispatchTilingData {
+    Mc2InitTiling mc2InitTiling;
+    Mc2CcTiling mc2CcTiling1;
+    Mc2CcTiling mc2CcTiling2;
+    CamMoeDistributeDispatchInfo moeDistributeDispatchInfo;
+};
+} // namespace Cam
+
+#endif
diff --git a/csrc/deepep/ops2/op_kernel/comm_args.h b/csrc/deepep/ops2/op_kernel/comm_args.h
new file mode 100644
index 00000000..fcbf076a
--- /dev/null
+++ b/csrc/deepep/ops2/op_kernel/comm_args.h
@@ -0,0 +1,84 @@
+#ifndef COMM_ARGS_H
+#define COMM_ARGS_H
+#include
+#include
+
+#define FORCE_INLINE_AICORE __attribute__((always_inline)) inline __aicore__
+#include "kernel_operator.h"
+
+namespace Moe {
+constexpr int CAM_MAX_RANK_SIZE = 384;  // Maximum number of NPU cards supported by the communication library
+
+constexpr uint64_t NOTIFY_DISPATCH_BUFF_OFFSET = 404UL * 1024UL * 1024UL;
+constexpr int64_t IPC_BUFF_MAX_SIZE = 200 * 1024 * 1024;
+constexpr int64_t IPC_DATA_OFFSET = 2 * 1024 * 1024;  // First 2MB as flag, then 100MB as data storage
+constexpr int64_t PING_PONG_SIZE = 2;
+constexpr int64_t UB_SINGLE_DMA_SIZE_MAX = 190 * 1024;
+constexpr int64_t SMALL_DATA_SIZE = 1 * 1024 * 1024;
+constexpr int64_t UB_SINGLE_PING_PONG_ADD_SIZE_MAX = UB_SINGLE_DMA_SIZE_MAX / 2;
+constexpr int UB_ALIGN_SIZE = 32;
+constexpr int64_t MAGIC_ALIGN_COUNT = UB_ALIGN_SIZE / sizeof(int32_t);
+
+constexpr uint8_t COMM_NUM = 2;  // Number of communication domains (EP and TP)
+constexpr uint8_t COMM_EP_IDX = 0;
+constexpr uint8_t COMM_TP_IDX = 1;
+
+constexpr int DFX_COUNT = 50;
+constexpr int64_t WAIT_SUCCESS = 112233445566;
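+// Layout implied by the constants above: each rank's IPC buffer begins with an IPC_DATA_OFFSET (2 MB)
+// flag/counter region, followed by the data region, with IPC_BUFF_MAX_SIZE capping the whole buffer.
+// A sanity check along these lines could be added if desired (illustrative only):
+// static_assert(IPC_DATA_OFFSET < IPC_BUFF_MAX_SIZE, "flag region must fit inside the IPC buffer");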
+constexpr int64_t IPC_CHUNK_FLAG = 0; // Start offset for send recv, chunk flag region +constexpr int64_t MAX_WAIT_ROUND_UNIT = + 10 * 1000 * 1000; // Threshold for waiting to get Flag under normal conditions within the same SIO + +constexpr static int32_t UB_HEAD_OFFSET = 96; +constexpr static int32_t UB_MID_OFFSET = UB_HEAD_OFFSET + UB_SINGLE_PING_PONG_ADD_SIZE_MAX + UB_ALIGN_SIZE; +constexpr static int64_t UB_FLAG_SIZE = 2 * 1024; +constexpr static int64_t MAX_CORE_NUM = 48; +constexpr static uint64_t STATE_WIN_OFFSET = 900 * 1024; +constexpr static int64_t COMPARE_ALIGN_SIZE = 256; + +constexpr static int64_t UB_SINGLE_TOTAL_SIZE_MAX = 192 * 1024; +constexpr static int64_t START_OFFSET_FOR_SHARE = 512; + +enum Op : int { COPYONLY = -1, ADD = 0, MUL = 1, MAX = 2, MIN = 3 }; + +template +constexpr T T_MAX = std::numeric_limits::max(); + +template +inline __aicore__ T CeilDiv(const T dividend, const T divisor) +{ + static_assert(std::is_arithmetic::value, "T must be an arithmetic type"); + if (divisor == 0 || dividend + divisor - 1 < dividend) { + return T_MAX; + } + return (dividend + divisor - 1) / divisor; +} + +struct CommArgs { + int rank = 0; // attr rank_id, global rank + int localRank = -1; + int rankSize = 0; // global rank size + int localRankSize = -1; // This parameter refers to the number of cards interconnected in fullmesh + uint32_t extraFlag = 0; // 32 bit map, the specific meaning of each bit is above in this file + int testFlag = 0; + GM_ADDR peerMems[CAM_MAX_RANK_SIZE] = + {}; // Buffer obtained from initialization, all allreduce is the same parameter + /** + * @param sendCountMatrix One-dimensional array with a size of rankSize*rankSize + * eg: The value of sendCountMatrix[1] corresponds to the [0][1] of the two-dimensional array, indicating the number + * of data that card 0 needs to send to card 1 + */ + int64_t sendCountMatrix[CAM_MAX_RANK_SIZE * CAM_MAX_RANK_SIZE] = {}; // for all2allvc + int64_t sendCounts[CAM_MAX_RANK_SIZE] = {}; // for all2allv + int64_t sdispls[CAM_MAX_RANK_SIZE] = {}; // for all2allv + int64_t recvCounts[CAM_MAX_RANK_SIZE] = {}; // for all2allv + int64_t rdispls[CAM_MAX_RANK_SIZE] = {}; // for all2allv + int64_t batchSize; + int64_t hiddenSize; + int64_t topk; + int64_t sharedExpertRankNum; + int64_t expertNumPerRank; + int64_t dfx[DFX_COUNT] = {}; +}; +} // namespace Moe +#endif // COMM_ARGS_H diff --git a/csrc/deepep/ops2/op_kernel/data_copy.h b/csrc/deepep/ops2/op_kernel/data_copy.h new file mode 100644 index 00000000..47443e67 --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/data_copy.h @@ -0,0 +1,68 @@ +#ifndef CAM_DATACOPY_GM2GM_H +#define CAM_DATACOPY_GM2GM_H +#include +#include "comm_args.h" + +using namespace AscendC; +using namespace Moe; + +template +FORCE_INLINE_AICORE void SetAtomicOpType(int op) +{ + switch (op) { + case ADD: + AscendC::SetAtomicAdd(); + break; + case MUL: + // Ignore setting the atomic register when performing mul + break; + case MAX: + AscendC::SetAtomicMax(); + break; + case MIN: + AscendC::SetAtomicMin(); + break; + default: + AscendC::SetAtomicNone(); + } +} + +template +FORCE_INLINE_AICORE void CpUB2GM(__gm__ T *gmAddr, __ubuf__ T *ubAddr, uint32_t size) +{ + LocalTensor ubTensor; + GlobalTensor gmTensor; + DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + ubTensor.address_.logicPos = static_cast(TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(ubAddr); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(gmAddr)); + DataCopyPad(gmTensor, ubTensor, dataCopyParams); +} + 
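+// Illustrative call sequence (not used by the kernels directly): accumulate `n` floats from a UB buffer
+// into GM with atomic add, then restore normal copy semantics. `gmDst` and `ubSrc` are hypothetical
+// pointers; note the size argument of CpUB2GM/CpGM2UB is in bytes.
+//     SetAtomicOpType<float>(Moe::ADD);
+//     CpUB2GM<float>(gmDst, ubSrc, n * sizeof(float));
+//     AscendC::SetAtomicNone();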
+template +FORCE_INLINE_AICORE void CpGM2UB(__ubuf__ T *ubAddr, __gm__ T *gmAddr, uint32_t size) +{ + LocalTensor ubTensor; + GlobalTensor gmTensor; + DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + ubTensor.address_.logicPos = static_cast(TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(ubAddr); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(gmAddr)); + DataCopyPadExtParams padParams; + DataCopyPad(ubTensor, gmTensor, dataCopyParams, padParams); +} + +template +FORCE_INLINE_AICORE void CopyUB2UB(__ubuf__ T *dst, __ubuf__ T *src, const uint32_t calCount) +{ + LocalTensor srcTensor; + LocalTensor dstTensor; + TBuffAddr srcAddr, dstAddr; + srcAddr.bufferAddr = reinterpret_cast(src); + dstAddr.bufferAddr = reinterpret_cast(dst); + srcTensor.SetAddr(srcAddr); + dstTensor.SetAddr(dstAddr); + DataCopy(dstTensor, srcTensor, calCount); +} + +#endif // CAM_DATACOPY_GM2GM_H diff --git a/csrc/deepep/ops2/op_kernel/dispatch_layout.cpp b/csrc/deepep/ops2/op_kernel/dispatch_layout.cpp new file mode 100644 index 00000000..a0e6450b --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/dispatch_layout.cpp @@ -0,0 +1,27 @@ +#include "kernel_operator.h" +#include "dispatch_layout.h" +#include "dispatch_layout_a2.h" +#include "dispatch_layout_tiling.h" + +#define TILING_KEY_INT 23 +#define TILING_KEY_A2_INT 123 + +extern "C" __global__ __aicore__ void dispatch_layout(GM_ADDR topkIdx, GM_ADDR numTokensPerRank, + GM_ADDR numTokensPerExpert, GM_ADDR isTokenInRank, + GM_ADDR totalData, GM_ADDR workspace, GM_ADDR tiling) +{ + REGISTER_TILING_DEFAULT(DispatchLayoutTilingData); + GET_TILING_DATA_WITH_STRUCT(DispatchLayoutTilingData, tilingData, tiling); + + TPipe pipe; + + if (TILING_KEY_IS(TILING_KEY_INT)) { + MoeDispatchLayout::DispatchLayout op; + op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, totalData, workspace, &pipe, &tilingData); + op.Process(); + } else if (TILING_KEY_IS(TILING_KEY_A2_INT)) { + MoeDispatchLayoutA2::DispatchLayoutA2 op; + op.Init(topkIdx, numTokensPerRank, numTokensPerExpert, isTokenInRank, totalData, workspace, &pipe, &tilingData); + op.Process(); + } +} diff --git a/csrc/deepep/ops2/op_kernel/dispatch_layout.h b/csrc/deepep/ops2/op_kernel/dispatch_layout.h new file mode 100644 index 00000000..5aa54de6 --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/dispatch_layout.h @@ -0,0 +1,157 @@ +#ifndef DISPATCH_LAYOUT_H +#define DISPATCH_LAYOUT_H + +#include +#include "kernel_operator.h" + +#include "comm_args.h" +#include "data_copy.h" +#include "sync_collectives.h" +#include "moe_distribute_base.h" +#include "dispatch_layout_tiling.h" +namespace MoeDispatchLayout { + +constexpr uint32_t UB_32_ALIGN = 32U; + +template +__aicore__ inline void SyncFunc() +{ + int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); +} + +using namespace AscendC; +using namespace Moe; +template +class DispatchLayout +{ +public: + __aicore__ inline DispatchLayout(){}; + + __aicore__ inline void Init(GM_ADDR topkIdx, GM_ADDR numTokensPerRank, GM_ADDR numTokensPerExpert, + GM_ADDR isTokenInRank, GM_ADDR totalData, GM_ADDR workspace, TPipe *pipe, + const DispatchLayoutTilingData *tilingData) + { + numTokens_ = tilingData->dispatchLayoutInfo.numTokens; + numRanks_ = tilingData->dispatchLayoutInfo.numRanks; + numExperts_ = tilingData->dispatchLayoutInfo.numExperts; + numTopk_ = tilingData->dispatchLayoutInfo.numTopk; + tpipe_ = pipe; + + coreIdx_ = GetBlockIdx(); + uint32_t maxAivNum = GetBlockNum(); 
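+        // Tokens are split as evenly as possible across at most GetBlockNum() cores: every core gets
+        // numTokens_ / aivNum_ tokens and the first numTokens_ % aivNum_ cores get one extra, which is
+        // what the offset computation below encodes. For example, 10 tokens on 4 cores (aivNum_ = 4)
+        // gives per-core counts 3, 3, 2, 2 starting at token rows 0, 3, 6 and 8.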
+ aivNum_ = numTokens_ <= maxAivNum ? numTokens_ : maxAivNum; + uint32_t temp = numTokens_ / aivNum_; + uint32_t restNum = numTokens_ % aivNum_; + int64_t topkIdxOffset; + int64_t isTokenOffset; + tempTokens_ = temp; + if (coreIdx_ < aivNum_) { + if (coreIdx_ < restNum) { + tempTokens_++; + } + topkIdx32AlignIntLen_ = Ceil(tempTokens_ * numTopk_ * sizeof(int64_t), UB_32_ALIGN) * UB_32_ALIGN; + numTokensPerRank32AlignIntLen_ = Ceil(numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + numTokensPerExpert32AlignIntLen_ = Ceil(numExperts_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + isTokenInRank32AlignIntLen_ = Ceil(tempTokens_ * numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + + if (coreIdx_ < restNum) { + topkIdxOffset = coreIdx_ * tempTokens_ * numTopk_ * sizeof(int64_t); + isTokenOffset = coreIdx_ * tempTokens_ * numRanks_ * sizeof(T); + } else { + topkIdxOffset = (restNum + coreIdx_ * tempTokens_) * numTopk_ * sizeof(int64_t); + isTokenOffset = (restNum + coreIdx_ * tempTokens_) * numRanks_ * sizeof(T); + } + + topkIdxGM_.SetGlobalBuffer((__gm__ int64_t *)(topkIdx + topkIdxOffset)); + numTokensPerRankGM_.SetGlobalBuffer((__gm__ T *)numTokensPerRank); + numTokensPerExpertGM_.SetGlobalBuffer((__gm__ T *)numTokensPerExpert); + isTokenInRankGM_.SetGlobalBuffer((__gm__ T *)(isTokenInRank + isTokenOffset)); + } + } + + __aicore__ inline void Process() + { + tpipe_->Reset(); + if (coreIdx_ < aivNum_) { + tpipe_->InitBuffer(topkIdxBuf_, topkIdx32AlignIntLen_); + tpipe_->InitBuffer(numTokensPerRankBuf_, numTokensPerRank32AlignIntLen_); + tpipe_->InitBuffer(numTokensPerExpertBuf_, numTokensPerExpert32AlignIntLen_); + tpipe_->InitBuffer(isTokenInRankBuf_, isTokenInRank32AlignIntLen_); + tpipe_->InitBuffer(seenRankBuf_, numRanks_ * sizeof(T)); + + LocalTensor topkIdxTensor = topkIdxBuf_.AllocTensor(); + const DataCopyExtParams dataCopyParams{1U, topkIdx32AlignIntLen_, 0U, 0U, 0U}; + const DataCopyPadExtParams padParams{false, 0U, 0U, 0U}; + DataCopyPad(topkIdxTensor, topkIdxGM_, dataCopyParams, padParams); + SyncFunc(); + + LocalTensor numTokensPerRankTensor = numTokensPerRankBuf_.AllocTensor(); + LocalTensor numTokensPerExpertTensor = numTokensPerExpertBuf_.AllocTensor(); + LocalTensor isTokenInRankTensor = isTokenInRankBuf_.AllocTensor(); + LocalTensor seenRankTensor = seenRankBuf_.AllocTensor(); + Duplicate(numTokensPerRankTensor, 0, numRanks_); + Duplicate(numTokensPerExpertTensor, 0, numExperts_); + Duplicate(isTokenInRankTensor, 0, tempTokens_ * numRanks_); + SyncFunc(); + + int experts_per_rank = numExperts_ / numRanks_; + for (int i = 0; i < tempTokens_; ++i) { + SyncFunc(); + Duplicate(seenRankTensor, 0, numRanks_); + SyncFunc(); + for (int j = 0; j < numTopk_; ++j) { + int64_t expert_idx = topkIdxTensor.GetValue(i * numTopk_ + j); + uint32_t per_expert_num = numTokensPerExpertTensor.GetValue(expert_idx) + 1; + numTokensPerExpertTensor.SetValue(expert_idx, per_expert_num); + int rank_id = expert_idx / experts_per_rank; + if (!seenRankTensor.GetValue(rank_id)) { + uint32_t per_rank_num = numTokensPerRankTensor.GetValue(rank_id) + 1; + isTokenInRankTensor.SetValue(i * numRanks_ + rank_id, 1); + seenRankTensor.SetValue(rank_id, 1); + numTokensPerRankTensor.SetValue(rank_id, per_rank_num); + } + } + } + + const DataCopyExtParams isTokenInRankDataCopyParams{1U, isTokenInRank32AlignIntLen_, 0U, 0U, 0U}; + DataCopyPad(isTokenInRankGM_, isTokenInRankTensor, isTokenInRankDataCopyParams); + AscendC::SetAtomicAdd(); + const DataCopyExtParams numTokensPerRankDataCopyParams{1U, 
numTokensPerRank32AlignIntLen_, 0U, 0U, 0U}; + DataCopyPad(numTokensPerRankGM_, numTokensPerRankTensor, numTokensPerRankDataCopyParams); + const DataCopyExtParams numTokensPerExpertDataCopyParams{1U, numTokensPerExpert32AlignIntLen_, 0U, 0U, 0U}; + DataCopyPad(numTokensPerExpertGM_, numTokensPerExpertTensor, numTokensPerExpertDataCopyParams); + AscendC::SetAtomicNone(); + } + } + +private: + GlobalTensor topkIdxGM_; + GlobalTensor numTokensPerRankGM_; + GlobalTensor numTokensPerExpertGM_; + GlobalTensor isTokenInRankGM_; + + TBuf<> topkIdxBuf_; + TBuf<> numTokensPerRankBuf_; + TBuf<> numTokensPerExpertBuf_; + TBuf<> isTokenInRankBuf_; + TBuf<> seenRankBuf_; + + TPipe *tpipe_{nullptr}; + uint32_t numTokens_{0}; + uint32_t numRanks_{0}; + uint32_t numExperts_{0}; + uint32_t numTopk_{0}; + uint32_t coreIdx_{0}; + uint32_t aivNum_{0}; + uint32_t tempTokens_{0}; + + uint32_t topkIdx32AlignIntLen_{0}; + uint32_t numTokensPerRank32AlignIntLen_{0}; + uint32_t numTokensPerExpert32AlignIntLen_{0}; + uint32_t isTokenInRank32AlignIntLen_{0}; +}; +} // namespace MoeDispatchLayout + +#endif // DISPATCH_LAYOUT_H diff --git a/csrc/deepep/ops2/op_kernel/dispatch_layout_a2.h b/csrc/deepep/ops2/op_kernel/dispatch_layout_a2.h new file mode 100644 index 00000000..76d17046 --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/dispatch_layout_a2.h @@ -0,0 +1,348 @@ +#ifndef DISPATCH_LAYOUT_A2_H +#define DISPATCH_LAYOUT_A2_H + +#include +#include "kernel_operator.h" + +#include "comm_args.h" +#include "data_copy.h" +#include "sync_collectives.h" +#include "moe_distribute_base.h" +#include "dispatch_layout_tiling.h" + +namespace MoeDispatchLayoutA2 { + +constexpr uint32_t UB_32_ALIGN = 32U; +constexpr uint32_t MAX_BATCH_SIZE = 4096U; +constexpr uint32_t TEMP_BATCH_SIZE = 8U; + +template +__aicore__ inline void SyncFunc() +{ + int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); +} + +using namespace AscendC; +using namespace Moe; +template +class DispatchLayoutA2 +{ +public: + __aicore__ inline DispatchLayoutA2(){}; + + __aicore__ inline void Init(GM_ADDR topkIdx, GM_ADDR numTokensPerRank, GM_ADDR numTokensPerExpert, + GM_ADDR isTokenInRank, GM_ADDR totalData, GM_ADDR workspace, TPipe *pipe, + const DispatchLayoutTilingData *tilingData) + { + numTokens_ = tilingData->dispatchLayoutInfo.numTokens; + numRanks_ = tilingData->dispatchLayoutInfo.numRanks; + numExperts_ = tilingData->dispatchLayoutInfo.numExperts; + numTopk_ = tilingData->dispatchLayoutInfo.numTopk; + localRankSize_ = tilingData->dispatchLayoutInfo.localRankSize; + serverNum_ = (numRanks_ + localRankSize_ - 1) / localRankSize_; + tpipe_ = pipe; + + coreIdx_ = GetBlockIdx(); + uint32_t maxAivNum = GetBlockNum() - 1; + aivNum_ = numTokens_ <= maxAivNum ? 
numTokens_ : maxAivNum; + uint32_t temp = numTokens_ / aivNum_; + uint32_t restNum = numTokens_ % aivNum_; + int64_t topkIdxOffset; + int64_t isTokenOffset; + int64_t serverOffsetOffset; + int64_t serverNumOffset; + tempTokens_ = temp; + + if (coreIdx_ < aivNum_) { + if (coreIdx_ < restNum) { + tempTokens_++; + } + topkIdx32AlignIntLen_ = Ceil(tempTokens_ * numTopk_ * sizeof(int64_t), UB_32_ALIGN) * UB_32_ALIGN; + numTokensPerRank32AlignIntLen_ = Ceil(numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + numTokensPerExpert32AlignIntLen_ = Ceil(numExperts_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + isTokenInRank32AlignIntLen_ = Ceil(tempTokens_ * numRanks_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + localTokenServerOffset32AlignIntLen_ = + Ceil(tempTokens_ * serverNum_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + localTokenServerUniqCount32AlignIntLen_ = Ceil(serverNum_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + localTokenServerTotalCount32AlignIntLen_ = + Ceil(tempTokens_ * serverNum_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + localTokenServerNum32AlignIntLen_ = Ceil(tempTokens_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + + if (coreIdx_ < restNum) { + topkIdxOffset = coreIdx_ * tempTokens_ * numTopk_ * sizeof(int64_t); + isTokenOffset = coreIdx_ * tempTokens_ * numRanks_ * sizeof(T); + serverOffsetOffset = coreIdx_ * tempTokens_ * serverNum_ * sizeof(T); + serverNumOffset = coreIdx_ * tempTokens_ * sizeof(T); + } else { + topkIdxOffset = (restNum + coreIdx_ * tempTokens_) * numTopk_ * sizeof(int64_t); + isTokenOffset = (restNum + coreIdx_ * tempTokens_) * numRanks_ * sizeof(T); + serverOffsetOffset = (restNum + coreIdx_ * tempTokens_) * serverNum_ * sizeof(T); + serverNumOffset = (restNum + coreIdx_ * tempTokens_) * sizeof(T); + } + + topkIdxGM_.SetGlobalBuffer((__gm__ int64_t *)(topkIdx + topkIdxOffset)); + numTokensPerRankGM_.SetGlobalBuffer((__gm__ T *)numTokensPerRank); + numTokensPerExpertSrcGM_.SetGlobalBuffer((__gm__ T *)numTokensPerExpert); + numTokensPerExpertGM_.SetGlobalBuffer((__gm__ T *)totalData); + isTokenInRankGM_.SetGlobalBuffer((__gm__ T *)(isTokenInRank + isTokenOffset)); + localTokenServerUniqCountGM_.SetGlobalBuffer((__gm__ T *)(totalData) + numExperts_); + localTokenServerTotalCountGM_.SetGlobalBuffer((__gm__ T *)(totalData + serverOffsetOffset) + numExperts_ + + serverNum_); + localTokenServerNumGM_.SetGlobalBuffer((__gm__ T *)(totalData + serverNumOffset) + numExperts_ + + serverNum_ * (MAX_BATCH_SIZE + 1)); + localTokenServerOffsetGM_.SetGlobalBuffer((__gm__ T *)(totalData + serverOffsetOffset) + numExperts_ + + serverNum_ + MAX_BATCH_SIZE * (serverNum_ + 1)); + } + if (coreIdx_ == aivNum_) { + expertRankTokenIdx32AlignIntLen_ = + Ceil(numExperts_ * TEMP_BATCH_SIZE * 2 * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + localTokenServerOffset32AlignIntLen_ = Ceil(numTokens_ * serverNum_ * sizeof(T), UB_32_ALIGN) * UB_32_ALIGN; + topkIdxGM_.SetGlobalBuffer((__gm__ int64_t *)topkIdx); + localTokenServerOffsetGM_.SetGlobalBuffer((__gm__ T *)totalData + numExperts_ + serverNum_ + + MAX_BATCH_SIZE * (serverNum_ + 1)); + sendTokenIdxGM_.SetGlobalBuffer((__gm__ T *)totalData + numExperts_ + serverNum_ + + MAX_BATCH_SIZE * (1 + 2 * serverNum_)); + expertRankTokenIdxGM_.SetGlobalBuffer((__gm__ T *)totalData + numExperts_ + serverNum_ + + MAX_BATCH_SIZE * (1 + 2 * serverNum_ + numExperts_)); + } + } + + __aicore__ inline void Process() + { + if (coreIdx_ < aivNum_) { + MultiCoreCompute(); + } + SyncAll(); + if (coreIdx_ == aivNum_) { + ComputeServerOffset(); + } + PRINTF("======[layout] 
block:%d \n", coreIdx_); + } + +private: + __aicore__ inline void MultiCoreCompute() + { + tpipe_->Reset(); + tpipe_->InitBuffer(topkIdxBuf_, topkIdx32AlignIntLen_); + tpipe_->InitBuffer(numTokensPerRankBuf_, numTokensPerRank32AlignIntLen_); + tpipe_->InitBuffer(numTokensPerExpertBuf_, numTokensPerExpert32AlignIntLen_); + tpipe_->InitBuffer(isTokenInRankBuf_, isTokenInRank32AlignIntLen_); + tpipe_->InitBuffer(localTokenServerOffsetBuf_, localTokenServerOffset32AlignIntLen_); + tpipe_->InitBuffer(localTokenServerUniqCountBuf_, localTokenServerUniqCount32AlignIntLen_); + tpipe_->InitBuffer(localTokenServerTotalCountBuf_, localTokenServerTotalCount32AlignIntLen_); + tpipe_->InitBuffer(localTokenServerNumBuf_, localTokenServerNum32AlignIntLen_); + tpipe_->InitBuffer(seenRankBuf_, numRanks_ * sizeof(T)); + tpipe_->InitBuffer(seenServerBuf_, serverNum_ * sizeof(T)); + LocalTensor topkIdxTensor = topkIdxBuf_.AllocTensor(); + const DataCopyExtParams dataCopyParams{1U, topkIdx32AlignIntLen_, 0U, 0U, 0U}; + const DataCopyPadExtParams padParams{false, 0U, 0U, 0U}; + DataCopyPad(topkIdxTensor, topkIdxGM_, dataCopyParams, padParams); + SyncFunc(); + LocalTensor numTokensPerRankTensor = numTokensPerRankBuf_.AllocTensor(); + LocalTensor numTokensPerExpertTensor = numTokensPerExpertBuf_.AllocTensor(); + LocalTensor isTokenInRankTensor = isTokenInRankBuf_.AllocTensor(); + LocalTensor localTokenServerOffsetTensor = localTokenServerOffsetBuf_.AllocTensor(); + LocalTensor localTokenServerUniqCountTensor = localTokenServerUniqCountBuf_.AllocTensor(); + LocalTensor localTokenServerTotalCountTensor = localTokenServerTotalCountBuf_.AllocTensor(); + LocalTensor localTokenServerNumTensor = localTokenServerNumBuf_.AllocTensor(); + LocalTensor seenRankTensor = seenRankBuf_.AllocTensor(); + LocalTensor seenServerTensor = seenServerBuf_.AllocTensor(); + Duplicate(numTokensPerRankTensor, 0, numRanks_); + Duplicate(numTokensPerExpertTensor, 0, numExperts_); + Duplicate(isTokenInRankTensor, 0, tempTokens_ * numRanks_); + Duplicate(localTokenServerOffsetTensor, -1, tempTokens_ * serverNum_); + Duplicate(localTokenServerUniqCountTensor, 0, serverNum_); + Duplicate(localTokenServerTotalCountTensor, 0, tempTokens_ * serverNum_); + Duplicate(localTokenServerNumTensor, 0, tempTokens_); + SyncFunc(); + int experts_per_rank = numExperts_ / numRanks_; + for (int i = 0; i < tempTokens_; ++i) { + SyncFunc(); + Duplicate(seenRankTensor, 0, numRanks_); + Duplicate(seenServerTensor, 0, serverNum_); + SyncFunc(); + for (int j = 0; j < numTopk_; ++j) { + int64_t expert_idx = topkIdxTensor.GetValue(i * numTopk_ + j); + uint32_t per_expert_num = numTokensPerExpertTensor.GetValue(expert_idx) + 1; + numTokensPerExpertTensor.SetValue(expert_idx, per_expert_num); + int rank_id = expert_idx / experts_per_rank; + int server_id = rank_id / localRankSize_; + if (!seenServerTensor.GetValue(server_id)) { + localTokenServerOffsetTensor.SetValue(i * serverNum_ + server_id, 1); + uint32_t uniqCount = localTokenServerUniqCountTensor.GetValue(server_id); + localTokenServerUniqCountTensor.SetValue(server_id, uniqCount + 1); + seenServerTensor.SetValue(server_id, 1); + uint32_t sendServerNum = localTokenServerNumTensor.GetValue(i); + localTokenServerNumTensor.SetValue(i, sendServerNum + 1); + } + uint32_t totalCount = localTokenServerTotalCountTensor.GetValue(i * serverNum_ + server_id) + 1; + localTokenServerTotalCountTensor.SetValue(i * serverNum_ + server_id, totalCount); + if (!seenRankTensor.GetValue(rank_id)) { + uint32_t per_rank_num = 
numTokensPerRankTensor.GetValue(rank_id) + 1; + isTokenInRankTensor.SetValue(i * numRanks_ + rank_id, 1); + seenRankTensor.SetValue(rank_id, 1); + numTokensPerRankTensor.SetValue(rank_id, per_rank_num); + } + } + } + uint32_t sendSize = tempTokens_ * numRanks_ * sizeof(T); + const DataCopyExtParams isTokenInRankDataCopyParams{1U, sendSize, 0U, 0U, 0U}; + sendSize = tempTokens_ * sizeof(T); + DataCopyPad(isTokenInRankGM_, isTokenInRankTensor, isTokenInRankDataCopyParams); + const DataCopyExtParams localTokenServerNumParams{1U, sendSize, 0U, 0U, 0U}; + DataCopyPad(localTokenServerNumGM_, localTokenServerNumTensor, localTokenServerNumParams); + sendSize = tempTokens_ * serverNum_ * sizeof(T); + const DataCopyExtParams localTokenServerOffsetParams{1U, sendSize, 0U, 0U, 0U}; + DataCopyPad(localTokenServerOffsetGM_, localTokenServerOffsetTensor, localTokenServerOffsetParams); + const DataCopyExtParams localTokenServerTotalCountParams{1U, sendSize, 0U, 0U, 0U}; + DataCopyPad(localTokenServerTotalCountGM_, localTokenServerTotalCountTensor, localTokenServerTotalCountParams); + sendSize = serverNum_ * sizeof(T); + AscendC::SetAtomicAdd(); + const DataCopyExtParams localTokenServerUniqCountParams{1U, sendSize, 0U, 0U, 0U}; + DataCopyPad(localTokenServerUniqCountGM_, localTokenServerUniqCountTensor, localTokenServerUniqCountParams); + const DataCopyExtParams numTokensPerRankDataCopyParams{1U, numTokensPerRank32AlignIntLen_, 0U, 0U, 0U}; + DataCopyPad(numTokensPerRankGM_, numTokensPerRankTensor, numTokensPerRankDataCopyParams); + const DataCopyExtParams numTokensPerExpertDataCopyParams{1U, numTokensPerExpert32AlignIntLen_, 0U, 0U, 0U}; + DataCopyPad(numTokensPerExpertGM_, numTokensPerExpertTensor, numTokensPerExpertDataCopyParams); + DataCopyPad(numTokensPerExpertSrcGM_, numTokensPerExpertTensor, numTokensPerExpertDataCopyParams); + AscendC::SetAtomicNone(); + } + + __aicore__ inline void ComputeServerOffset() + { + tpipe_->Reset(); + tpipe_->InitBuffer(localTokenServerOffsetBuf_, localTokenServerOffset32AlignIntLen_); + tpipe_->InitBuffer(seenServerBuf_, serverNum_ * sizeof(T)); + tpipe_->InitBuffer(expertRankTokenIdxBuf_, expertRankTokenIdx32AlignIntLen_); + tpipe_->InitBuffer(countExpertBuf_, numExperts_ * sizeof(T)); + LocalTensor localTokenServerOffsetTensor = localTokenServerOffsetBuf_.AllocTensor(); + LocalTensor seenServerTensor = seenServerBuf_.AllocTensor(); + const DataCopyExtParams dataCopyParams{1U, localTokenServerOffset32AlignIntLen_, 0U, 0U, 0U}; + const DataCopyPadExtParams padParams{false, 0U, 0U, 0U}; + DataCopyPad(localTokenServerOffsetTensor, localTokenServerOffsetGM_, dataCopyParams, padParams); + SyncFunc(); + Duplicate(seenServerTensor, 0, serverNum_); + SyncFunc(); + for (int i = 0; i < numTokens_; i++) { + for (int j = 0; j < serverNum_; j++) { + int32_t value = localTokenServerOffsetTensor.GetValue(i * serverNum_ + j); + if (value > 0) { + int32_t offset = seenServerTensor.GetValue(j); + localTokenServerOffsetTensor.SetValue(i * serverNum_ + j, offset); + seenServerTensor.SetValue(j, offset + 1); + } + } + } + SyncFunc(); + DataCopyPad(localTokenServerOffsetGM_, localTokenServerOffsetTensor, dataCopyParams); + LocalTensor countExpertTensor = countExpertBuf_.AllocTensor(); + LocalTensor expertRankTokenIdxTensor = expertRankTokenIdxBuf_.AllocTensor(); + Duplicate(countExpertTensor, 0, numExperts_); + SyncFunc(); + int32_t experts_per_rank = numExperts_ / numRanks_; + for (int i = 0; i < numTokens_; i++) { + for (int j = 0; j < numTopk_; j++) { + int32_t expert_id = 
topkIdxGM_.GetValue(i * numTopk_ + j); + int32_t server_id = (expert_id / experts_per_rank) / localRankSize_; + int32_t offset = localTokenServerOffsetTensor.GetValue(i * serverNum_ + server_id); + int32_t count = countExpertTensor.GetValue(expert_id); + expertRankTokenIdxTensor.SetValue(expert_id * TEMP_BATCH_SIZE + count % TEMP_BATCH_SIZE, offset); + expertRankTokenIdxTensor.SetValue((numExperts_ + expert_id) * TEMP_BATCH_SIZE + count % TEMP_BATCH_SIZE, + i); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + sendTokenIdxGM_[i * numExperts_ + expert_id]); + __asm__ __volatile__(""); + sendTokenIdxGM_.SetValue(i * numExperts_ + expert_id, count); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + sendTokenIdxGM_[i * numExperts_ + expert_id]); + __asm__ __volatile__(""); + count++; + countExpertTensor.SetValue(expert_id, count); + if (count % TEMP_BATCH_SIZE == 0) { + SyncFunc(); + const DataCopyExtParams expertRankTokendataCopyParams{1U, TEMP_BATCH_SIZE * sizeof(T), 0U, 0U, 0U}; + DataCopyPad(expertRankTokenIdxGM_[expert_id * MAX_BATCH_SIZE + count - TEMP_BATCH_SIZE], + expertRankTokenIdxTensor[expert_id * TEMP_BATCH_SIZE], expertRankTokendataCopyParams); + DataCopyPad( + expertRankTokenIdxGM_[(numExperts_ + expert_id) * MAX_BATCH_SIZE + count - TEMP_BATCH_SIZE], + expertRankTokenIdxTensor[(numExperts_ + expert_id) * TEMP_BATCH_SIZE], + expertRankTokendataCopyParams); + SyncFunc(); + Duplicate(expertRankTokenIdxTensor[expert_id * TEMP_BATCH_SIZE], 0, TEMP_BATCH_SIZE); + Duplicate(expertRankTokenIdxTensor[(numExperts_ + expert_id) * TEMP_BATCH_SIZE], 0, + TEMP_BATCH_SIZE); + SyncFunc(); + } + } + } + for (int i = 0; i < numExperts_; i++) { + int32_t count = countExpertTensor.GetValue(i); + uint32_t rest = count % TEMP_BATCH_SIZE; + if (rest) { + SyncFunc(); + const DataCopyExtParams expertRankTokendataCopyParams{1U, uint32_t(rest * sizeof(T)), 0U, 0U, 0U}; + DataCopyPad(expertRankTokenIdxGM_[i * MAX_BATCH_SIZE + count - rest], + expertRankTokenIdxTensor[i * TEMP_BATCH_SIZE], expertRankTokendataCopyParams); + DataCopyPad(expertRankTokenIdxGM_[(i + numExperts_) * MAX_BATCH_SIZE + count - rest], + expertRankTokenIdxTensor[(i + numExperts_) * TEMP_BATCH_SIZE], + expertRankTokendataCopyParams); + SyncFunc(); + } + } + } + + GlobalTensor topkIdxGM_; + GlobalTensor numTokensPerRankGM_; + GlobalTensor numTokensPerExpertGM_; + GlobalTensor numTokensPerExpertSrcGM_; + GlobalTensor isTokenInRankGM_; + GlobalTensor localTokenServerOffsetGM_; + GlobalTensor localTokenServerUniqCountGM_; + GlobalTensor localTokenServerTotalCountGM_; + GlobalTensor localTokenServerNumGM_; + GlobalTensor expertRankTokenIdxGM_; + GlobalTensor sendTokenIdxGM_; + + TBuf<> topkIdxBuf_; + TBuf<> numTokensPerRankBuf_; + TBuf<> numTokensPerExpertBuf_; + TBuf<> isTokenInRankBuf_; + TBuf<> localTokenServerOffsetBuf_; + TBuf<> localTokenServerUniqCountBuf_; + TBuf<> localTokenServerTotalCountBuf_; + TBuf<> localTokenServerNumBuf_; + TBuf<> seenRankBuf_; + TBuf<> seenServerBuf_; + TBuf<> countExpertBuf_; + TBuf<> expertRankTokenIdxBuf_; + + TPipe *tpipe_{nullptr}; + uint32_t numTokens_{0}; + uint32_t numRanks_{0}; + uint32_t numExperts_{0}; + uint32_t numTopk_{0}; + uint32_t localRankSize_{0}; + uint32_t serverNum_{0}; + uint32_t coreIdx_{0}; + uint32_t aivNum_{0}; + uint32_t tempTokens_{0}; + + uint32_t topkIdx32AlignIntLen_{0}; + uint32_t numTokensPerRank32AlignIntLen_{0}; + uint32_t numTokensPerExpert32AlignIntLen_{0}; + uint32_t isTokenInRank32AlignIntLen_{0}; + uint32_t 
localTokenServerOffset32AlignIntLen_{0}; + uint32_t localTokenServerUniqCount32AlignIntLen_{0}; + uint32_t localTokenServerTotalCount32AlignIntLen_{0}; + uint32_t localTokenServerNum32AlignIntLen_{0}; + uint32_t expertRankTokenIdx32AlignIntLen_{0}; +}; +} // namespace MoeDispatchLayoutA2 + +#endif // DISPATCH_LAYOUT_A2_H diff --git a/csrc/deepep/ops2/op_kernel/dispatch_layout_tiling.h b/csrc/deepep/ops2/op_kernel/dispatch_layout_tiling.h new file mode 100644 index 00000000..af1d0eae --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/dispatch_layout_tiling.h @@ -0,0 +1,21 @@ +#ifndef DISPATCH_LAYOUT_TILING_H +#define DISPATCH_LAYOUT_TILING_H + +#include "kernel_tiling/kernel_tiling.h" + +struct DispatchLayoutInfo { + uint32_t numTokens; + uint32_t numRanks; + uint32_t numExperts; + uint32_t numTopk; + uint32_t localRankSize; + uint64_t totalUbSize; +}; + +struct DispatchLayoutTilingData { + Mc2InitTiling mc2InitTiling; + Mc2CcTiling mc2CcTiling1; + DispatchLayoutInfo dispatchLayoutInfo; +}; + +#endif diff --git a/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp b/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp new file mode 100644 index 00000000..48cf519f --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/dispatch_normal_a2.cpp @@ -0,0 +1,100 @@ +#include "kernel_operator.h" +// #include "notify_dispatch_a2.h" +// #include "notify_dispatch_tiling_a2.h" +// #include "a2/a2.h" +#include "a2/cam_moe_distribute_dispatch_a2_layered.h" +#include "cam_moe_distribute_dispatch_tiling.h" + +#define TILING_KEY_FLOAT16 20 +#define TILING_KEY_BFLOAT16 21 +#define TILING_KEY_FLOAT 22 +#define TILING_KEY_INT 23 +#define TILING_KEY_A2_FLOAT16 120 +#define TILING_KEY_A2_BFLOAT16 121 +#define TILING_KEY_A2_FLOAT 122 +#define TILING_KEY_A2_INT 123 + +#define KERNEL_USE_WORKSPACE (1 * 1024 * 1024) + +using namespace AscendC; +using namespace MoeDistributeDispatchA2Impl; +using namespace Cam; + +extern "C" __global__ __aicore__ void dispatch_normal_a2( + GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR xActiveMask, GM_ADDR expertScales, GM_ADDR tokenServerIdx, + GM_ADDR tokenServerCnt, GM_ADDR epRankTokenCnt, GM_ADDR srcOffsetRankTokenIdx, GM_ADDR dstOffsetRankTokenIdx, + GM_ADDR recvX, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, GM_ADDR epRecvCountOut, + GM_ADDR expandScalesOut, GM_ADDR dispatchWaitRecvCostStatsOut, GM_ADDR workspace, GM_ADDR tiling) +{ + printf("[dispatch_normal_a2] blockId: %d\n", GetBlockIdx()); + // REGISTER_TILING_DEFAULT(NotifyDispatchA2TilingData); + // GET_TILING_DATA_WITH_STRUCT(NotifyDispatchA2TilingData, tilingData, tiling); + REGISTER_TILING_DEFAULT(CamMoeDistributeDispatchA2TilingData); + GET_TILING_DATA_WITH_STRUCT(CamMoeDistributeDispatchA2TilingData, tilingData, tiling); + + // hcomm will set magic later in init + uint32_t magic = 1; + GM_ADDR commArgs = nullptr; + + // int localRank = tilingData.notifyDispatchInfoA2.localRankId; + // int localRankSize = tilingData.notifyDispatchInfoA2.localRankSize; + // int rank = tilingData.notifyDispatchInfoA2.rankId; + // int rankSize = tilingData.notifyDispatchInfoA2.rankSize; + // int64_t len = tilingData.notifyDispatchInfoA2.sendCount; + // int64_t numTokens = tilingData.notifyDispatchInfoA2.numTokens; + // int64_t topkNum = tilingData.notifyDispatchInfoA2.topkNum; + // int64_t numExperts = tilingData.notifyDispatchInfoA2.numExperts; + + // GM_ADDR sendDataInput = sendData; + // GM_ADDR tokenPerExpertDataInput = tokenPerExpertData; + // GM_ADDR sendDataOffsetOutput = sendDataOffset; + // GM_ADDR 
recvDataOutput = recvData; + // GM_ADDR tokenServerIdxOutput = tokenServerIdx; + // GM_ADDR tokensUniquePerServerOutput = tokensUniquePerServer; + // GM_ADDR epRankTokenCntOutput = epRankTokenCnt; + // GM_ADDR localEpTokenCntOutput = localEpTokenCnt; + // GM_ADDR srcOffsetRankTokenIdxOutput = srcOffsetRankTokenIdx; + // GM_ADDR dstOffsetRankTokenIdxOutput = dstOffsetRankTokenIdx; + // GM_ADDR offsetInnerOutput = offsetInner; + // GM_ADDR countOuterOutput = countOuter; + + // fill in unused args + uint32_t extraFlag = 0; + GM_ADDR scale = nullptr; + int root = 0; + int op = 0; + int cycleCount = 0; + int64_t scaleCount = 0; + GM_ADDR offset = nullptr; + int blockNum = GetBlockNum(); + + TPipe pipe; + if (TILING_KEY_IS(2100001000)) { + // NotifyDispatchA2 opKernel(rank, rankSize, extraFlag); + // opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL()); + // opKernel.Process(); + CamMoeDistributeDispatchA2Layered op; + op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt, + srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut, + epRecvCountOut, expandScalesOut, workspace, &pipe, tiling); + op.Process(); + } else if (TILING_KEY_IS(2000000000)) { + // NotifyDispatchA2 opKernel(rank, rankSize, extraFlag); + // opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL()); + // opKernel.Process(); + CamMoeDistributeDispatchA2Layered op; + op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt, + srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut, + epRecvCountOut, expandScalesOut, workspace, &pipe, tiling); + op.Process(); + } else if (TILING_KEY_IS(2000001000)) { + // NotifyDispatchA2 opKernel(rank, rankSize, extraFlag); + // opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL()); + // opKernel.Process(); + CamMoeDistributeDispatchA2Layered op; + op.Init(x, expertIds, scales, expertScales, tokenServerIdx, tokenServerCnt, epRankTokenCnt, + srcOffsetRankTokenIdx, dstOffsetRankTokenIdx, recvX, dynamicScalesOut, expandIdxOut, expertTokenNumsOut, + epRecvCountOut, expandScalesOut, workspace, &pipe, tiling); + op.Process(); + } +} diff --git a/csrc/deepep/ops2/op_kernel/moe_distribute_base.h b/csrc/deepep/ops2/op_kernel/moe_distribute_base.h new file mode 100644 index 00000000..b899a0e4 --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/moe_distribute_base.h @@ -0,0 +1,199 @@ +/*! 
+ * \file moe_distribute_base.h + * \brief + */ + +#ifndef MOE_DISTRIBUTE_BASE_H +#define MOE_DISTRIBUTE_BASE_H + +/* system tick: 50MHz */ +#define CAL_US(tick) (((tick) * 2) / 100) + +/* performance macro */ +// #define USE_256_TO_1__ +#ifdef USE_256_TO_1__ +#pragma message("use 256 to 1") +#else +#define USE_FOR_OPT__ +#define DISPATCH_USE_WRITE_SHUFFLE__ +#define USE_TOKEN_COUNT_SPLIT__ +#define USE_ONE_CORE_WAIT__ + +#ifdef USE_ONE_CORE_WAIT__ +#pragma message("use one core wait") + +// #define USE_ONE_CORE_GETCUMSUM__ +#endif +#ifdef USE_FOR_OPT__ +#pragma message("use for optimization") +#define FOR_OPT_MAX_BS__ 64 +#define FOR_OPT_MAX_MOE_RANK__ 256 +#endif +// #define COMBINE_USE_DYNAMIC_QUANT +#define OPT_RANK_OFFSET 512 +#define USE_WRITE_SHUFFLE +#endif + +constexpr uint32_t LOCAL_NOTIFY_MAX_NUM = 64; +constexpr uint32_t LOCAL_STREAM_MAX_NUM = 19; +constexpr uint32_t AICPU_OP_NOTIFY_MAX_NUM = 2; +constexpr uint32_t AICPU_MAX_RANK_NUM = 128 * 1024; + +struct HcclSignalInfo { + uint64_t resId; + uint64_t addr; + uint32_t devId; + uint32_t tsId; + uint32_t rankId; + uint32_t flag; +}; + +struct ListCommon { + uint64_t nextHost; + uint64_t preHost; + uint64_t nextDevice; + uint64_t preDevice; +}; + +struct HcclStreamInfo { + int32_t streamIds; + uint32_t sqIds; + uint32_t cqIds; + uint32_t logicCqids; +}; + +struct LocalResInfoV2 { + uint32_t streamNum; + uint32_t signalNum; + HcclSignalInfo localSignals[LOCAL_NOTIFY_MAX_NUM]; + HcclStreamInfo streamInfo[LOCAL_STREAM_MAX_NUM]; + HcclStreamInfo mainStreamInfo; + HcclSignalInfo aicpuOpNotify[AICPU_OP_NOTIFY_MAX_NUM]; + ListCommon nextTagRes; // HccltagLocalResV2 +}; + +enum class rtFloatOverflowMode_t { + RT_OVERFLOW_MODE_SATURATION = 0, + RT_OVERFLOW_MODE_INFNAN, + RT_OVERFLOW_MODE_UNDEF, +}; + +struct AlgoTopoInfo { + uint32_t userRank; // RankID + uint32_t userRankSize; // Rank Number + int32_t deviceLogicId; + bool isSingleMeshAggregation; + uint32_t deviceNumPerAggregation; + uint32_t superPodNum; + uint32_t devicePhyId; + uint32_t topoType; // TopoType + uint32_t deviceType; + uint32_t serverNum; + uint32_t meshAggregationRankSize; + uint32_t multiModuleDiffDeviceNumMode; + uint32_t multiSuperPodDiffServerNumMode; + uint32_t realUserRank; + bool isDiffDeviceModule; + bool isDiffDeviceType; + uint32_t gcdDeviceNumPerAggregation; + uint32_t moduleNum; + uint32_t isUsedRdmaRankPairNum; + uint64_t isUsedRdmaRankPair; + uint32_t pairLinkCounterNum; + uint64_t pairLinkCounter; + uint32_t nicNum; + uint64_t nicList; + uint64_t complanRankLength; + uint64_t complanRank; + uint64_t bridgeRankNum; + uint64_t bridgeRank; + uint64_t serverAndsuperPodRankLength; + uint64_t serverAndsuperPodRank; +}; + +struct HcclOpConfig { + uint8_t deterministic; + uint8_t retryEnable; + uint8_t highPerfEnable; + uint8_t padding[5]; + uint8_t linkTimeOut[8]; + uint64_t notifyWaitTime; + uint32_t retryHoldTime; + uint32_t retryIntervalTime; + bool interHccsDisable = false; + rtFloatOverflowMode_t floatOverflowMode = rtFloatOverflowMode_t::RT_OVERFLOW_MODE_UNDEF; + uint32_t multiQpThreshold = 512; +}; + +struct HcclMC2WorkSpace { + uint64_t workSpace; + uint64_t workSpaceSize; +}; + +struct RemoteResPtr { + uint64_t nextHostPtr; + uint64_t nextDevicePtr; +}; + +struct HDCommunicateParams { + uint64_t hostAddr{0}; + uint64_t deviceAddr{0}; + uint64_t readCacheAddr{0}; + uint32_t devMemSize{0}; + uint32_t buffLen{0}; + uint32_t flag{0}; +}; + +struct HcclRankRelationResV2 { + uint32_t remoteUsrRankId; + uint32_t remoteWorldRank; + uint64_t windowsIn; + 
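+    // The windows* fields appear to be the remote peer's HCCL window base addresses: windowsIn/windowsOut
+    // back the GetWindowsInAddr/GetWindowsOutAddr accesses used by the kernels in this patch, and
+    // windowsExp is the expanded/detect window (see winExpSize/localWindowsExp in HcclOpResParam below).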
uint64_t windowsOut; + uint64_t windowsExp; + ListCommon nextTagRes; +}; + +struct HcclOpResParam { + // local resource + HcclMC2WorkSpace mc2WorkSpace; + uint32_t localUsrRankId; // usrrankid + uint32_t rankSize; + uint64_t winSize; + uint64_t localWindowsIn; + uint64_t localWindowsOut; + char hcomId[128]; + // aicore detect remote window + uint64_t winExpSize; + uint64_t localWindowsExp; + uint32_t rWinStart; + uint32_t rWinOffset; + uint64_t version; + LocalResInfoV2 localRes; + AlgoTopoInfo topoInfo; + + // config parameters + HcclOpConfig config; + uint64_t hostStateInfo; + uint64_t aicpuStateInfo; + uint64_t lockAddr; + uint32_t rsv[16]; + uint32_t notifysize; + uint32_t remoteResNum; + RemoteResPtr remoteRes[AICPU_MAX_RANK_NUM]; + + // communicate retry + HDCommunicateParams kfcControlTransferH2DParams; + HDCommunicateParams kfcStatusTransferD2HParams; + uint64_t tinyMem; // for all2all + uint64_t tinyMemSize; + // zero-copy + uint64_t zeroCopyHeadPtr; + uint64_t zeroCopyTailPtr; + uint64_t zeroCopyRingBuffer; + uint64_t zeroCopyIpcPtrs[16]; + uint32_t zeroCopyDevicePhyId[16]; + + bool utraceStatusFlag; +}; + +#endif // MOE_DISTRIBUTE_BASE_H diff --git a/csrc/deepep/ops2/op_kernel/moe_distribute_combine_a2.cpp b/csrc/deepep/ops2/op_kernel/moe_distribute_combine_a2.cpp new file mode 100644 index 00000000..ca3cb3cf --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/moe_distribute_combine_a2.cpp @@ -0,0 +1,39 @@ +#include "kernel_operator.h" +#include "moe_distribute_combine_a2_tiling.h" +#include "moe_distribute_combine_a2.h" +#include "moe_distribute_combine_a2_layered.h" +#include + +using namespace AscendC; +using namespace MoeDistributeCombineA2Impl; +extern "C" __global__ __aicore__ void moe_distribute_combine_a2( + GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR scales, GM_ADDR tpSendCount, + GM_ADDR xActiveMask, GM_ADDR activationScale, GM_ADDR weightScale, GM_ADDR groupList, GM_ADDR expandScales, + GM_ADDR offsetInner, GM_ADDR offsetOuter, GM_ADDR countOuter, GM_ADDR XOut, GM_ADDR workspaceGM, GM_ADDR tilingGM) + +{ + printf("===========combine_a2=============\n"); + REGISTER_TILING_DEFAULT(MoeDistributeCombineA2TilingData); + TPipe pipe; +#if (ORIG_DTYPE_EXPAND_X == DT_BF16 || ORIG_DTYPE_EXPAND_X == DT_FLOAT16) + if (TILING_KEY_IS(2000)) { + GET_TILING_DATA_WITH_STRUCT(MoeDistributeCombineA2TilingData, tilingData, tilingGM); + auto tiling = (__gm__ MoeDistributeCombineA2TilingData *)tilingGM; + __gm__ void *mc2InitTiling = (__gm__ void *)(&(tiling->mc2InitTiling)); + __gm__ void *mc2CcTiling = (__gm__ void *)(&(tiling->mc2CcTiling)); + MoeDistributeCombineA2 op; + op.Init(expandX, expertIds, expandIdx, epSendCount, scales, XOut, workspaceGM, &pipe, &tilingData, + mc2InitTiling, mc2CcTiling); + op.Process(); + } else if (TILING_KEY_IS(3000)) { + GET_TILING_DATA_WITH_STRUCT(MoeDistributeCombineA2TilingData, tilingData, tilingGM); + auto tiling = (__gm__ MoeDistributeCombineA2TilingData *)tilingGM; + __gm__ void *mc2InitTiling = (__gm__ void *)(&(tiling->mc2InitTiling)); + __gm__ void *mc2CcTiling = (__gm__ void *)(&(tiling->mc2CcTiling)); + MoeDistributeCombineA2Layered op; + op.Init(expandX, expandIdx, epSendCount, offsetInner, offsetOuter, countOuter, expandScales, XOut, workspaceGM, + &pipe, &tilingData, mc2InitTiling, mc2CcTiling); + op.Process(); + } +#endif +} diff --git a/csrc/deepep/ops2/op_kernel/moe_distribute_combine_a2.h b/csrc/deepep/ops2/op_kernel/moe_distribute_combine_a2.h new file mode 100644 index 00000000..0d46c237 --- 
/dev/null
+++ b/csrc/deepep/ops2/op_kernel/moe_distribute_combine_a2.h
@@ -0,0 +1,550 @@
+/**
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+/*!
+ * \file moe_distribute_combine_a2.h
+ * \brief
+ */
+#ifndef MOE_DISTRIBUTE_COMBINE_A2_H
+#define MOE_DISTRIBUTE_COMBINE_A2_H
+#include "kernel_operator.h"
+#include "kernel_tiling/kernel_tiling.h"
+#include "moe_distribute_combine_a2_tiling.h"
+#include "moe_distribute_base.h"
+namespace {
+constexpr uint8_t BUFFER_NUM = 2;                       // double buffering
+constexpr uint32_t STATE_OFFSET = 512;                  // status space offset (per-rank stride)
+constexpr uint32_t STATE_SPACE_SIZE = 1024 * 1024;      // 1M
+constexpr uint32_t UB_ALIGN = 32;                       // UB is aligned to 32 bytes
+constexpr uint32_t SELF_STATE_OFFSET = 512 * 1024;      // status space offset of the local rank
+constexpr uint32_t BATCH_WRITE_ITEM_OFFSET = 8 * 1024;  // offset of the batchWriteInfo structs within the last 1MB of windowOut
+constexpr uint32_t BATCH_WRITE_ITEM_SIZE = 32;
+constexpr uint32_t BLOCK_SIZE = 32;
+constexpr uint32_t B32_PER_BLOCK = 8;
+constexpr uint32_t B64_PER_BLOCK = 4;
+constexpr uint32_t SKIP_OFFSET = 32;
+constexpr uint32_t FLAG_VALUE = 0xFFFFFFFF;
+constexpr uint64_t MB_SIZE = 1024 * 1024;
+template
+__aicore__ inline void SyncFunc()
+{
+    int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event));
+    AscendC::SetFlag(eventID);
+    AscendC::WaitFlag(eventID);
+}
+template
+inline __aicore__ T RoundUp(const T val, const T align)
+{
+    if (align == 0 || val + align - 1 < val) {
+        return val;
+    }
+    return (val + align - 1) / align * align;
+}
+
+struct TaskInfo {
+    uint32_t startTaskId;
+    uint32_t endTaskId;
+    uint32_t taskNum;
+
+    __aicore__ inline TaskInfo() {}
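+    // Splits taskNumTotal tasks as evenly as possible across aivNum cores: every core gets
+    // taskNumTotal / aivNum tasks and the first taskNumTotal % aivNum cores get one extra.
+    // For example, SplitCore(10, 4, aivId) yields [startTaskId, endTaskId) = [0,3), [3,6), [6,8), [8,10)
+    // for aivId = 0..3.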
+    __aicore__ inline void SplitCore(uint32_t taskNumTotal, uint32_t aivNum, uint32_t aivId)
+    {
+        if (aivNum == 0) {
+            startTaskId = 0;
+            endTaskId = 0;
+            taskNum = 0;
+            return;
+        }
+
+        uint32_t formerNum = taskNumTotal / aivNum;
+        uint32_t tailNum = taskNumTotal % aivNum;
+        startTaskId = formerNum * aivId;
+        if (aivId < tailNum) {
+            formerNum++;
+            startTaskId += aivId;
+        } else {
+            startTaskId += tailNum;
+        }
+        taskNum = formerNum;
+        endTaskId = startTaskId + taskNum;
+    }
+};
+
+} // namespace
+namespace MoeDistributeCombineA2Impl {
+#define TemplateMC2TypeA2Class typename ExpandXType, typename ExpandIdxType
+#define TemplateMC2TypeA2Func ExpandXType, ExpandIdxType
+using namespace AscendC;
+template
+class MoeDistributeCombineA2
+{
+public:
+    __aicore__ inline MoeDistributeCombineA2(){};
+    __aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR sendCount,
+                                GM_ADDR scales, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe,
+                                const MoeDistributeCombineA2TilingData *tilingData, __gm__ void *mc2InitTiling,
+                                __gm__ void *mc2CcTiling);
+    __aicore__ inline void Process();
+
+private:
+    __aicore__ inline void LocalWindowCopy();
+    __aicore__ inline void AlltoAllDispatch();
+    __aicore__ inline void BuffInit();
+    __aicore__ inline void SplitCoreCal();
+    __aicore__ inline void Preload();
+    __aicore__ inline void WaitDispatch();
+    TPipe *tpipe_{nullptr};
+    GlobalTensor expandXGlobal_;
+    GlobalTensor expertIdsGlobal_;
+    GlobalTensor expandIdxGlobal_;
+    GlobalTensor sendCountGlobal_;
+    GlobalTensor expandScalesGlobal_;
+    GlobalTensor expandOutGlobal_;
+    GlobalTensor rankWindow_;         // holds the peer rank's window
+    GlobalTensor localOutWindow_;
+    GlobalTensor localInWindow_;
+    GlobalTensor windowInstatusTensor_;
+    GlobalTensor bufferIdGlobal_;     // buffer id kept in the window status area
+    GlobalTensor workspaceGlobal_;    // stores the batchWriteInfo structs
+    GlobalTensor workspaceGlobal32_;  // stores the batchWriteInfo structs
+    GlobalTensor flagGlobal_;
+    LocalTensor batchWriteItemLocalB64;
+    LocalTensor batchWriteItemLocalB32;
+    LocalTensor recvCountLocal_;
+    LocalTensor expertWindowOffsetLocal_;
+    LocalTensor rowTmpFloatLocal_;
+    LocalTensor mulBufLocal_;
+    LocalTensor sumFloatLocal_;
+    LocalTensor expertIdsLocal_;
+    LocalTensor expandScalesLocal_;
+    LocalTensor indexCountsLocal_;
+    LocalTensor tmpUb_;
+    LocalTensor statusTensor_;
+    GM_ADDR windowInGM_;
+    GM_ADDR windowOutGM_;
+    GM_ADDR expandXGM_;
+    GM_ADDR expertIdsGM_;
+    GM_ADDR expandIdxGM_;
+    GM_ADDR sendCountGM_;
+    GM_ADDR scalesGM_;
+    GM_ADDR XOutGM_;
+    // the tiling side guarantees the data upper bounds, so products cannot overflow and uint32_t is used throughout
+    uint32_t axisBS_{0};
+    uint32_t axisH_{0};
+    uint32_t axisK_{0};                // topK
+    uint32_t aivNum_{0};
+    uint32_t worldSize_{0};
+    uint32_t rankId_{0};
+    uint32_t coreIdx_{0};              // aiv id
+    uint32_t sharedExpertRankNum_{0};  // number of shared-expert ranks
+    uint32_t moeExpertNum_{0};         // number of moe experts, equals worldSize_ minus the shared-expert ranks
+    uint32_t localMoeExpertNum_{0};    // number of experts per rank
+    uint32_t expandXRows_;
+    uint64_t rankSizeOnWin_{0};
+    uint64_t dataOffsetOnWin_{0};
+    uint64_t stateOffsetOnWin_{0};
+    uint32_t axisHFloatSize_{0};
+    uint32_t axisHExpandXTypeSize_{0};
+    uint32_t bsKAlign_{0};
+    uint32_t startRankId_{0};
+    uint32_t endRankId_{0};
+    uint32_t sendRankNum_{0};
+    uint32_t halfWinSize_{0};
+    uint32_t dataSpaceSize_{0};
+    uint32_t bufferId_{0};
+    uint32_t tokenNumPerCore_{0};
+    uint32_t tokenIndex_{0};
+    TQueBind moeQueue_;
+    TQue moeSumQueue_;
+    TBuf<> expertIdsBuf_;
+    TBuf<> expandScalesBuf_;
+    TBuf<> rowTmpFloatBuf_;
+    TBuf<> sumFloatBuf_;
+    TBuf<> mulBuf_;
+    TBuf<> sendCountBuf_;
+    TBuf<> indexCountsBuf_;
+    TBuf<> tokenBuf_;
+    TBuf<> statusBuf_;
+    TBuf<> batchWriteItemBuf_;
+    TBuf<> recvCountBuf_;
+    TBuf<> expertWindowOffsetBuf_;
+
+    TaskInfo taskInfo_;
+
+    GlobalTensor expertRecvCountGlobal_;
+    GlobalTensor expertWindowOffsetGlobal_;
+
+    Hccl hccl_;
+    __gm__ HcclOpResParam *winContext_{nullptr};
+};
+template
+__aicore__ inline void MoeDistributeCombineA2::Init(
+    GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR sendCount, GM_ADDR scales, GM_ADDR XOut,
+    GM_ADDR workspaceGM, TPipe *pipe, const MoeDistributeCombineA2TilingData *tilingData, __gm__ void *mc2InitTiling,
+    __gm__ void *mc2CcTiling)
+{
+    tpipe_ = pipe;
+    expandXGM_ = expandX;
+    expertIdsGM_ = expertIds;
+    expandIdxGM_ = expandIdx;
+    sendCountGM_ = sendCount;
+    scalesGM_ = scales;
+    XOutGM_ = XOut;
+    rankId_ = tilingData->moeDistributeCombineInfo.epRankId;
+    axisBS_ = tilingData->moeDistributeCombineInfo.bs;
+    axisH_ = tilingData->moeDistributeCombineInfo.h;
+    axisK_ = tilingData->moeDistributeCombineInfo.k;
+    aivNum_ = tilingData->moeDistributeCombineInfo.aivNum;
+    moeExpertNum_ = tilingData->moeDistributeCombineInfo.moeExpertNum;
+    worldSize_ = tilingData->moeDistributeCombineInfo.epWorldSize;
+    auto contextGM = AscendC::GetHcclContext();
+    winContext_ = (__gm__ HcclOpResParam *)contextGM;
+    hccl_.Init(contextGM, mc2InitTiling);
+    hccl_.SetCcTiling(mc2CcTiling);
+    coreIdx_ = GetBlockIdx();
+    PRINTF("[Init] combine_a2, coreId:%d \n",
coreIdx_); + + /* + halfWinSize_ = winContext_->winSize / 2; + dataSpaceSize_ = halfWinSize_ - STATE_SPACE_SIZE; + windowInGM_ = hccl_.GetWindowsInAddr(rankId_); + bufferIdGlobal_.SetGlobalBuffer((__gm__ uint32_t *)(windowInGM_ + dataSpaceSize_)); + bufferId_ = bufferIdGlobal_.GetValue(0); + windowInGM_ = windowInGM_ + halfWinSize_ * bufferId_; + windowOutGM_ = hccl_.GetWindowsOutAddr(rankId_) + halfWinSize_ * bufferId_; + + windowInstatusTensor_.SetGlobalBuffer((__gm__ uint32_t *)(windowInGM_)); + expandXGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)expandX); + expertIdsGlobal_.SetGlobalBuffer((__gm__ ExpandIdxType *)expertIds); + expandIdxGlobal_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx); + sendCountGlobal_.SetGlobalBuffer((__gm__ int32_t *)sendCount); + expandScalesGlobal_.SetGlobalBuffer((__gm__ float *)scales); + expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut); + workspaceGlobal_.SetGlobalBuffer((__gm__ uint64_t *)(windowOutGM_ + dataSpaceSize_ + BATCH_WRITE_ITEM_OFFSET)); + workspaceGlobal32_.SetGlobalBuffer((__gm__ uint32_t *)(windowOutGM_ + dataSpaceSize_ + BATCH_WRITE_ITEM_OFFSET)); + + expertRecvCountGlobal_.SetGlobalBuffer((__gm__ uint32_t *)workspaceGM); + expertWindowOffsetGlobal_.SetGlobalBuffer((__gm__ uint32_t *)(workspaceGM + moeExpertNum_ * sizeof(uint32_t))); + + localMoeExpertNum_ = moeExpertNum_ / worldSize_; + expandXRows_ = localMoeExpertNum_ * axisBS_ * worldSize_; + rankSizeOnWin_ = dataSpaceSize_ / worldSize_ / BLOCK_SIZE * BLOCK_SIZE; + dataOffsetOnWin_ = rankId_ * rankSizeOnWin_; + stateOffsetOnWin_ = dataSpaceSize_ + rankId_ * STATE_OFFSET; + axisHFloatSize_ = axisH_ * sizeof(float); + axisHExpandXTypeSize_ = axisH_ * sizeof(ExpandXType); + bsKAlign_ = RoundUp(axisBS_ * axisK_, (uint32_t)8); + + uint64_t stateSizeMaxSize = 2 * STATE_SPACE_SIZE; // 2: 实际上是(DATA_OFFSET+SKIP_OFFSET+sizeof(uint32)) + + STATE_SPACE_SIZE,近似计算使用2 * STATE_SPACE_SIZE uint64_t winSizeMin = (axisBS_ * worldSize_ * (localMoeExpertNum_ + > axisK_ ? axisK_ : localMoeExpertNum_) * axisH_ * sizeof(uint16_t) + stateSizeMaxSize) * BUFFER_NUM; // + 考虑负载极其不均衡时,HCCL BUFFSIZE需要开的大小 + assert(winContext_->winSize >= winSizeMin, "The HCCL_BUFFSIZE is %lluMB, the min value should be %lluMB. 
\ + epWorldSize:%u, epRankId:%u, moeExpertNum:%u, globalBs:%u, bs:%u, k:%u, h:%u, aivNum:%u, \ + totalUbSize:%llu\n", + winContext_->winSize / MB_SIZE, winSizeMin / MB_SIZE, + tilingData->moeDistributeCombineInfo.epWorldSize, tilingData->moeDistributeCombineInfo.epRankId, + tilingData->moeDistributeCombineInfo.moeExpertNum, tilingData->moeDistributeCombineInfo.globalBs, + tilingData->moeDistributeCombineInfo.bs, tilingData->moeDistributeCombineInfo.k, + tilingData->moeDistributeCombineInfo.h, tilingData->moeDistributeCombineInfo.aivNum, + tilingData->moeDistributeCombineInfo.totalUbSize + ); + + BuffInit(); + SplitCoreCal(); + */ +} +template +__aicore__ inline void MoeDistributeCombineA2::BuffInit() +{ + tpipe_->InitBuffer(moeQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 7168 * 2 * 2 = 28672 + tpipe_->InitBuffer(statusBuf_, worldSize_ * UB_ALIGN); + tpipe_->InitBuffer(expertIdsBuf_, axisBS_ * axisK_ * sizeof(int32_t)); // 32 * 8 * 4 = 1024 + tpipe_->InitBuffer(expandScalesBuf_, axisBS_ * axisK_ * sizeof(float)); // 32 * 8 * 4 = 1024 + tpipe_->InitBuffer(tokenBuf_, axisHExpandXTypeSize_); // 7168 * 2 = 14336 + tpipe_->InitBuffer(rowTmpFloatBuf_, axisHFloatSize_); // 7168 * 4 = 28672 + tpipe_->InitBuffer(mulBuf_, axisHFloatSize_); // 7168 * 4 = 28672 + tpipe_->InitBuffer(sumFloatBuf_, axisHFloatSize_); // 7168 * 4 = 28672 + tpipe_->InitBuffer(sendCountBuf_, RoundUp(moeExpertNum_, B32_PER_BLOCK) * sizeof(int32_t)); + tpipe_->InitBuffer(indexCountsBuf_, axisBS_ * axisK_ * sizeof(int32_t)); // 32 * 8 * 4 = 1024 + tpipe_->InitBuffer(moeSumQueue_, BUFFER_NUM, axisHExpandXTypeSize_); + tpipe_->InitBuffer(batchWriteItemBuf_, BATCH_WRITE_ITEM_SIZE * worldSize_); + batchWriteItemLocalB64 = batchWriteItemBuf_.Get(); + batchWriteItemLocalB32 = batchWriteItemLocalB64.template ReinterpretCast(); +} +template +__aicore__ inline void MoeDistributeCombineA2::SplitCoreCal() +{ + // 对worldSize按卡分核,得到每个核上处理的卡的数量 + sendRankNum_ = worldSize_ / aivNum_; + uint32_t remainderRankNum = worldSize_ % aivNum_; + startRankId_ = sendRankNum_ * coreIdx_; + if (coreIdx_ < remainderRankNum) { + sendRankNum_++; + startRankId_ += coreIdx_; + } else { + startRankId_ += remainderRankNum; + } + endRankId_ = startRankId_ + sendRankNum_; +} +template +__aicore__ inline void MoeDistributeCombineA2::AlltoAllDispatch() +{ + if (sendRankNum_ == 0) { + SyncAll(); + return; + } + LocalTensor sendCountLocal = sendCountBuf_.Get(); + DataCopy(sendCountLocal, sendCountGlobal_, RoundUp(moeExpertNum_, B32_PER_BLOCK)); + SyncFunc(); + for (uint32_t dstRankId = startRankId_; dstRankId < endRankId_; ++dstRankId) { + localOutWindow_.SetGlobalBuffer((__gm__ ExpandXType *)(windowOutGM_ + dstRankId * rankSizeOnWin_)); + uint32_t rankTokenNum = 0; + for (uint32_t expertId = 0; expertId < localMoeExpertNum_; ++expertId) { + uint32_t preCount = 0; + if (expertId != 0 || dstRankId != 0) { + preCount = static_cast(sendCountLocal.GetValue(expertId * worldSize_ + dstRankId - 1)); + } + uint32_t startTokenAddr = preCount * axisH_; + uint32_t tokenNum = sendCountLocal(expertId * worldSize_ + dstRankId) - preCount; + for (uint32_t tokenId = 0; tokenId < tokenNum; ++tokenId) { + LocalTensor InUb = moeQueue_.AllocTensor(); + DataCopy(InUb, expandXGlobal_[startTokenAddr], axisH_); + moeQueue_.EnQue(InUb); + LocalTensor OutUb = moeQueue_.DeQue(); + DataCopy(localOutWindow_[rankTokenNum * axisH_], OutUb, axisH_); + moeQueue_.FreeTensor(OutUb); + startTokenAddr += axisH_; + rankTokenNum++; + } + } + flagGlobal_.SetGlobalBuffer( + (__gm__ uint32_t 
*)(localOutWindow_.GetPhyAddr(rankTokenNum * axisH_) + SKIP_OFFSET / sizeof(ExpandXType))); + flagGlobal_(0) = FLAG_VALUE; + uint32_t rankIdOffset = dstRankId - startRankId_; + batchWriteItemLocalB64(rankIdOffset * 4) = (uint64_t)(localOutWindow_.GetPhyAddr()); + batchWriteItemLocalB64(rankIdOffset * 4 + 1) = + (uint64_t)(hccl_.GetWindowsInAddr(dstRankId) + halfWinSize_ * bufferId_ + dataOffsetOnWin_); + batchWriteItemLocalB64(rankIdOffset * 4 + 2) = rankTokenNum * axisH_ + SKIP_OFFSET / sizeof(ExpandXType) + 2; + batchWriteItemLocalB32(rankIdOffset * 8 + 6) = HcclDataType::HCCL_DATA_TYPE_FP16; + batchWriteItemLocalB32(rankIdOffset * 8 + 7) = dstRankId; + DataCacheCleanAndInvalid( + flagGlobal_); + } + SyncFunc(); + DataCopy(workspaceGlobal_[startRankId_ * 4], batchWriteItemLocalB64, sendRankNum_ * 4); + SyncFunc(); + SyncAll(); + if ASCEND_IS_AIV { + if (coreIdx_ == 0) { + HcclHandle handleId = hccl_.BatchWrite((GM_ADDR)(workspaceGlobal_.GetPhyAddr()), worldSize_); + bufferIdGlobal_(0) = bufferId_ ^ 1; + } + if (rankId_ >= startRankId_ && rankId_ < endRankId_) { + localOutWindow_.SetGlobalBuffer((__gm__ ExpandXType *)(windowOutGM_ + dataOffsetOnWin_)); + localInWindow_.SetGlobalBuffer((__gm__ ExpandXType *)(windowInGM_ + dataOffsetOnWin_)); + uint32_t rankIdOffset = rankId_ - startRankId_; + uint64_t rankTokenNum = + (batchWriteItemLocalB64(rankIdOffset * 4 + 2) - SKIP_OFFSET / sizeof(ExpandXType) - 2) / axisH_; + for (uint32_t tokenId = 0; tokenId < rankTokenNum; ++tokenId) { + LocalTensor InUb = moeQueue_.AllocTensor(); + DataCopy(InUb, localOutWindow_[tokenId * axisH_], axisH_); + moeQueue_.EnQue(InUb); + LocalTensor OutUb = moeQueue_.DeQue(); + DataCopy(localInWindow_[tokenId * axisH_], OutUb, axisH_); + moeQueue_.FreeTensor(OutUb); + } + flagGlobal_.SetGlobalBuffer((__gm__ uint32_t *)localInWindow_.GetPhyAddr( + rankTokenNum * axisH_ + SKIP_OFFSET / sizeof(ExpandXType))); + flagGlobal_(0) = FLAG_VALUE; + DataCacheCleanAndInvalid( + flagGlobal_); + } + } +} +template +__aicore__ inline void MoeDistributeCombineA2::Preload() +{ + tpipe_->InitBuffer(recvCountBuf_, sizeof(uint32_t) * moeExpertNum_); + tpipe_->InitBuffer(expertWindowOffsetBuf_, sizeof(uint32_t) * moeExpertNum_); + recvCountLocal_ = recvCountBuf_.Get(); + expertWindowOffsetLocal_ = expertWindowOffsetBuf_.Get(); + expertIdsLocal_ = expertIdsBuf_.Get(); + DataCopy(expertIdsLocal_, expertIdsGlobal_, bsKAlign_); + Duplicate(recvCountLocal_, (uint32_t)0, moeExpertNum_); + Duplicate(expertWindowOffsetLocal_, (uint32_t)0, moeExpertNum_); + + SyncFunc(); + + if (coreIdx_ == aivNum_ - 1) { + DataCopyPad(expertRecvCountGlobal_, recvCountLocal_, + {1, static_cast(moeExpertNum_ * sizeof(uint32_t)), 0, 0, 0}); + } + + SyncAll(); + + taskInfo_.SplitCore(axisBS_ * axisK_, aivNum_, coreIdx_); + for (uint32_t i = taskInfo_.startTaskId; i < taskInfo_.endTaskId; ++i) { + uint32_t expId = expertIdsLocal_.GetValue(i); + recvCountLocal_(expId) += 1; + } + SyncFunc(); + + SetAtomicAdd(); + DataCopyPad(expertRecvCountGlobal_, recvCountLocal_, + {1, static_cast(moeExpertNum_ * sizeof(uint32_t)), 0, 0, 0}); + SetAtomicNone(); + + SyncAll(); + + DataCopyPad(recvCountLocal_, expertRecvCountGlobal_, + {1, static_cast(moeExpertNum_ * sizeof(uint32_t)), 0, 0, 0}, {false, 0, 0, 0}); + + SyncFunc(); + + taskInfo_.SplitCore(moeExpertNum_ / localMoeExpertNum_, aivNum_, coreIdx_); + for (uint32_t groupIdx = taskInfo_.startTaskId; groupIdx < taskInfo_.endTaskId; ++groupIdx) { + uint32_t start = groupIdx * localMoeExpertNum_; + uint32_t end = start + 
localMoeExpertNum_; + uint32_t prefixSum = 0; + for (uint32_t i = start; i < end; ++i) { + expertWindowOffsetLocal_(i - start) = prefixSum; + prefixSum += recvCountLocal_.GetValue(i); + } + SyncFunc(); + DataCopyPad(expertWindowOffsetGlobal_[start], expertWindowOffsetLocal_, + {1, static_cast(localMoeExpertNum_ * sizeof(uint32_t)), 0, 0, 0}); + SyncFunc(); + } + SyncAll(); + + DataCopyPad(expertWindowOffsetLocal_, expertWindowOffsetGlobal_, + {1, static_cast(moeExpertNum_ * sizeof(uint32_t)), 0, 0, 0}, {false, 0, 0, 0}); + + tokenNumPerCore_ = axisBS_ / aivNum_; + uint32_t undoTokenNum = axisBS_ % aivNum_; + tokenIndex_ = 0; + if (coreIdx_ < undoTokenNum) { + tokenNumPerCore_ = tokenNumPerCore_ + 1; + tokenIndex_ = coreIdx_ * tokenNumPerCore_; + } else { + tokenIndex_ = (undoTokenNum + coreIdx_ * tokenNumPerCore_); + } + if (tokenNumPerCore_ == 0) { + return; + } + rowTmpFloatLocal_ = rowTmpFloatBuf_.Get(); + mulBufLocal_ = mulBuf_.Get(); + sumFloatLocal_ = sumFloatBuf_.Get(); + expandScalesLocal_ = expandScalesBuf_.Get(); + indexCountsLocal_ = indexCountsBuf_.Get(); + DataCopy(indexCountsLocal_, expandIdxGlobal_, bsKAlign_); + DataCopy(expandScalesLocal_, expandScalesGlobal_, bsKAlign_); +} + +template +__aicore__ inline void MoeDistributeCombineA2::WaitDispatch() +{ + if (startRankId_ >= worldSize_) { + SyncAll(); + return; + } + SyncFunc(); + for (uint32_t waitFlagNum = 0; waitFlagNum < sendRankNum_;) { + waitFlagNum = 0; + for (uint32_t rankId = startRankId_; rankId < endRankId_; ++rankId) { + uint32_t tokenIdx = (rankId + 1) * localMoeExpertNum_ - 1; + GM_ADDR wAddr = windowInGM_ + rankSizeOnWin_ * rankId + SKIP_OFFSET + + (recvCountLocal_(tokenIdx) + expertWindowOffsetLocal_(tokenIdx)) * axisHExpandXTypeSize_; + flagGlobal_.SetGlobalBuffer((__gm__ uint32_t *)wAddr); + DataCacheCleanAndInvalid( + flagGlobal_); + uint32_t flag = flagGlobal_(0); + if (flag == FLAG_VALUE) { + waitFlagNum++; + } + } + } + for (uint32_t rankId = startRankId_; rankId < endRankId_; ++rankId) { + uint32_t tokenIdx = (rankId + 1) * localMoeExpertNum_ - 1; + GM_ADDR wAddr = windowInGM_ + rankSizeOnWin_ * rankId + SKIP_OFFSET + + (recvCountLocal_(tokenIdx) + expertWindowOffsetLocal_(tokenIdx)) * axisHExpandXTypeSize_; + flagGlobal_.SetGlobalBuffer((__gm__ uint32_t *)wAddr); + flagGlobal_(0) = 0; + } + SyncAll(); +} + +template +__aicore__ inline void MoeDistributeCombineA2::Process() +{ + PRINTF("[Process] combine_a2, coreId:%d \n", coreIdx_); + SyncAll(); + hccl_.Finalize(); + + /* + if ASCEND_IS_AIV { + AlltoAllDispatch(); + Preload(); + WaitDispatch(); + LocalWindowCopy(); + hccl_.Finalize(); + } + */ +} +template +__aicore__ inline void MoeDistributeCombineA2::LocalWindowCopy() +{ + if (tokenNumPerCore_ == 0) { + return; + } + // step 4 & step 5 + GM_ADDR wAddr; + int32_t expId = 0; + float scaleVal = 0.0; + for (uint32_t i = 0; i < tokenNumPerCore_; i++) { + uint32_t index = (tokenIndex_ + i) * axisK_; + Duplicate(sumFloatLocal_, 0.0f, axisH_); + for (uint32_t j = 0; j < axisK_; j++) { + expId = expertIdsLocal_.GetValue(index); + scaleVal = expandScalesLocal_.GetValue(index); + uint32_t rank = expId / localMoeExpertNum_; + wAddr = (__gm__ uint8_t *)(windowInGM_) + rankSizeOnWin_ * rank + + expertWindowOffsetLocal_.GetValue(expId) * axisHExpandXTypeSize_ + + indexCountsLocal_.GetValue(index) * axisHExpandXTypeSize_; + // copy experts from window + rankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)wAddr); + tmpUb_ = moeSumQueue_.AllocTensor(); + DataCopy(tmpUb_, rankWindow_, axisH_); + 
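// The remainder of this loop performs the standard MoE combine reduction for one token: each of
// the K expert outputs fetched from the window is cast to float, scaled by its routing weight
// (scaleVal), and accumulated into sumFloatLocal_, and the float sum is finally rounded back to
// the activation type. A hypothetical host-side reference of the same math (names and types are
// illustrative only, not the kernel implementation):
//
//   #include <cstddef>
//   #include <vector>
//   // expertOut: K rows of H values gathered for one token; scales: the K routing weights.
//   std::vector<float> CombineToken(const std::vector<std::vector<float>> &expertOut,
//                                   const std::vector<float> &scales, std::size_t H) {
//       std::vector<float> acc(H, 0.0f);
//       for (std::size_t k = 0; k < scales.size(); ++k) {
//           for (std::size_t d = 0; d < H; ++d) {
//               acc[d] += expertOut[k][d] * scales[k];
//           }
//       }
//       return acc;  // the kernel then casts back to fp16/bf16 with RoundMode::CAST_RINT
//   }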
moeSumQueue_.EnQue(tmpUb_); + LocalTensor tmpOtherUb_ = moeSumQueue_.DeQue(); + // cast before muls + Cast(rowTmpFloatLocal_, tmpOtherUb_, AscendC::RoundMode::CAST_NONE, axisH_); + PipeBarrier(); + // muls expert and scaleVal + AscendC::Muls(mulBufLocal_, rowTmpFloatLocal_, scaleVal, axisH_); + PipeBarrier(); + // add mulBufLocal to sumFloatBufLocal + AscendC::Add(sumFloatLocal_, sumFloatLocal_, mulBufLocal_, axisH_); + index++; + moeSumQueue_.FreeTensor(tmpOtherUb_); + } + // 结果搬出 + PipeBarrier(); + LocalTensor sumBufLocal_ = tokenBuf_.Get(); + SyncFunc(); + Cast(sumBufLocal_, sumFloatLocal_, AscendC::RoundMode::CAST_RINT, axisH_); + SyncFunc(); + DataCopy(expandOutGlobal_[(tokenIndex_ + i) * axisH_], sumBufLocal_, axisH_); + PipeBarrier(); + } +} +} // namespace MoeDistributeCombineA2Impl +#endif // MOE_DISTRIBUTE_COMBINE_A2_H diff --git a/csrc/deepep/ops2/op_kernel/moe_distribute_combine_a2_layered.h b/csrc/deepep/ops2/op_kernel/moe_distribute_combine_a2_layered.h new file mode 100644 index 00000000..da7ef8c7 --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/moe_distribute_combine_a2_layered.h @@ -0,0 +1,770 @@ +/** + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file moe_distribute_combine_a2_layered.h + * \brief + */ +#ifndef MOE_DISTRIBUTE_COMBINE_A2_LAYERED_H +#define MOE_DISTRIBUTE_COMBINE_A2_LAYERED_H +#include "kernel_operator.h" +#include "kernel_tiling/kernel_tiling.h" +#include "moe_distribute_combine_a2_tiling.h" +#include "moe_distribute_base.h" + +namespace MoeDistributeCombineA2Impl { + +#define TemplateMC2TypeA2layeredClass typename ExpandXType, typename ExpandIdxType +#define TemplateMC2TypeA2layeredFunc ExpandXType, ExpandIdxType +using namespace AscendC; +template +class MoeDistributeCombineA2Layered +{ +public: + constexpr static uint32_t BUFFER_NUM = 2U; // 多buf + constexpr static uint32_t STATE_OFFSET = 512U; // 状态空间偏移地址 + constexpr static uint32_t STATE_SPACE_SIZE = 1024U * 1024U; // 1M + constexpr static uint32_t UB_ALIGN = 32U; // UB按32字节对齐 + constexpr static uint32_t SELF_STATE_OFFSET = 512U * 1024U; // 本卡状态空间偏移地址 + constexpr static uint32_t BATCH_WRITE_ITEM_OFFSET = + 8U * 1024U; // batchWriteInfo结构体地址相对于windowOut最后1M的偏移 + constexpr static uint32_t BATCH_WRITE_ITEM_SIZE = 32U; + constexpr static uint32_t BLOCK_SIZE = 32U; + constexpr static uint32_t B32_PER_BLOCK = 8U; + constexpr static uint32_t B64_PER_BLOCK = 4U; + constexpr static uint32_t SERVER_RANK_SIZE = 8U; + constexpr static uint32_t IPC_DATA_OFFSET = 4U * 1024U * 1024U; + constexpr static uint32_t RDMA_DATA_SIZE = 300U * 1024U * 1024U; + constexpr static uint32_t EXTRA_TOKEN_INFO_NUM = 4U; // 专家信息 权重信息 量化Scale 到达标志位 + constexpr static uint64_t MB_SIZE = 1024UL * 1024UL; + constexpr static uint32_t MAX_BS = 4096; // 每卡支持的最大bs + + template + __aicore__ inline void SyncFunc() + { + int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); + } + template + inline __aicore__ T 
RoundUp(const T val, const T align) + { + static_assert(std::is_arithmetic::value, "T must be an arithmetic type"); + if (align == 0 || val + align - 1 < val) { + return val; + } + return (val + align - 1) / align * align; + } + + __aicore__ inline MoeDistributeCombineA2Layered(){}; + __aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expandIdx, GM_ADDR sendCount, GM_ADDR offsetInner, + GM_ADDR offsetOuter, GM_ADDR countOuter, GM_ADDR scales, GM_ADDR XOut, + GM_ADDR workspaceGM, TPipe *pipe, const MoeDistributeCombineA2TilingData *tilingData, + __gm__ void *mc2InitTiling, __gm__ void *mc2CcTiling); + __aicore__ inline void Process(); + +private: + __aicore__ inline void BuffInit(); + __aicore__ inline void SplitCoreCal(); + __aicore__ inline void AlltoAllDispatch(); + __aicore__ inline void SumToWindow(); + __aicore__ inline void SetStatus(); + __aicore__ inline void WaitDispatch(); + __aicore__ inline void AlltoAllServerDispatch(); + __aicore__ inline void SumToServer(); + __aicore__ inline void Preload(); + + TPipe *tpipe_{nullptr}; + GlobalTensor expandXGlobal_; + GlobalTensor expandIdxGlobal_; + GlobalTensor sendCountGlobal_; + GlobalTensor bkCountGlobal_; + GlobalTensor expandScalesGlobal_; + GlobalTensor expandOutGlobal_; + GlobalTensor rankWindow_; // 用于存对端window的变量 + GlobalTensor localOutWindow_; + GlobalTensor localInWindow_; + GlobalTensor bufferIdGlobal_; // 用于存对端状态window的变量 + GlobalTensor statusSpaceGlobal_; // win区状态位置拷入相关参数 + GlobalTensor workspaceGlobal_; // 存储batchWriteInfo结构体信息 + GlobalTensor workspaceGlobal32_; // 存储batchWriteInfo结构体信息 + GlobalTensor readStateGlobal_; + GlobalTensor dstRankStateGlobal_; + LocalTensor batchWriteItemLocalB64; + LocalTensor batchWriteItemLocalB32; + LocalTensor recvCountLocal_; + LocalTensor expertWindowOffsetLocal_; + LocalTensor rowTmpFloatLocal_; + LocalTensor mulBufLocal_; + LocalTensor sumFloatLocal_; + LocalTensor expertIdsLocal_; + LocalTensor expandScalesLocal_; + LocalTensor indexCountsLocal_; + LocalTensor tmpUb_; + uint64_t shareAddreRank[8]; + GlobalTensor selfRankshareMemGlobal_; + + GM_ADDR windowInGM_; + GM_ADDR windowOutGM_; + GM_ADDR statusSpaceGm_; + GM_ADDR expandXGM_; + GM_ADDR expertIdsGM_; + GM_ADDR expandIdxGM_; + GM_ADDR sendCountGM_; + GM_ADDR scalesGM_; + GM_ADDR XOutGM_; + + // 分层所需的参数 + GM_ADDR shareAddrGM_; + GM_ADDR offsetInnerGM_; + GM_ADDR countInnerGM_; + GM_ADDR offsetOuterGM_; + GM_ADDR countOuterGM_; + GM_ADDR recvCountInnerGM_; + GlobalTensor shareAddrGlobal_; + GlobalTensor shareFlagGlobal_; + GlobalTensor shareMemGlobal_; + GlobalTensor dstshareMemGlobal_; + GlobalTensor offsetInnerGlobal_; + // GlobalTensor countInnerGlobal_; + GlobalTensor offsetOuterGlobal_; + GlobalTensor countOuterGlobal_; + GlobalTensor recvCountInnerGlobal_; + TBuf<> offsetReduceBuf_; + TBuf<> countReduceBuf_; + // tiling侧已确保数据上限,相乘不会越界,因此统一采用uint32_t进行处理 + uint32_t countReL{0}; + uint32_t axisBS_{0}; + uint32_t globalBs{0}; + uint32_t axisH_{0}; + uint32_t axisK_{0}; // topK + uint32_t aivNum_{0}; + uint32_t worldSize_{0}; + uint32_t rankId_{0}; + uint32_t coreIdx_{0}; // aiv id + uint32_t sharedExpertRankNum_{0}; // 共享专家卡数 + __gm__ HcclOpResParam *winContext_{nullptr}; + uint32_t moeExpertNum_{0}; // moe专家数, 等于worldSize_ - 共享专家卡数 + uint32_t localMoeExpertNum_{0}; // 每张卡的专家数 + uint32_t expandXRows_; + uint64_t rankSizeOnWin_{0}; + Hccl hccl_; + uint64_t dataOffsetOnWin_{0}; + uint64_t stateOffsetOnWin_{0}; + uint32_t axisHFloatSize_{0}; + uint32_t axisHExpandXTypeSize_{0}; + uint32_t startRankId_{0}; + uint32_t endRankId_{0}; 
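// startRankId_ / endRankId_ / sendRankNum_ describe the contiguous block of destination ranks
// handled by this AIV core. SplitCoreCal() assigns worldSize_ / aivNum_ ranks per core and gives
// the first worldSize_ % aivNum_ cores one extra rank. An equivalent stand-alone sketch of that
// split (hypothetical free function, for illustration only):
//
//   #include <cstdint>
//   void SplitRanks(uint32_t worldSize, uint32_t aivNum, uint32_t coreIdx,
//                   uint32_t &startRank, uint32_t &rankCount) {
//       rankCount = worldSize / aivNum;
//       uint32_t remainder = worldSize % aivNum;
//       if (coreIdx < remainder) {
//           ++rankCount;                          // the first `remainder` cores take one extra rank
//           startRank = coreIdx * rankCount;
//       } else {
//           startRank = coreIdx * rankCount + remainder;
//       }
//       // The exclusive end is startRank + rankCount, matching endRankId_ = startRankId_ + sendRankNum_.
//   }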
+ uint32_t sendRankNum_{0}; + uint32_t halfWinSize_{0}; + uint32_t dataSpaceSize_{0}; + uint32_t bufferId_{0}; + uint32_t tokenNumPerCore_{0}; + uint32_t tokenIndex_{0}; + uint32_t serverNum{0}; + uint32_t ipcSliceSize{0}; + uint32_t ipcSliceNodeSize{0}; + uint64_t send_counts_inner_offset{0}; + uint64_t offset_inner_offset{0}; + uint64_t send_counts_outer_offset{0}; + uint64_t offset_outer_offset{0}; + uint64_t share_offset{0}; + uint32_t IPC_DATA_SIZE{0}; + TQueBind moeQueue_; + TQue moeSumQueue_; + TBuf<> rowTmpFloatBuf_; + TBuf<> sumFloatBuf_; + TBuf<> mulBuf_; + TBuf<> sendCountBuf_; + TBuf<> statusBuf_; + TBuf<> statusSumOutBuf_; + TBuf<> batchWriteItemBuf_; + TBuf<> recvCountBuf_; + TBuf<> scaleBuf_; + TBuf<> expertWindowOffsetBuf_; + int32_t sumTarget_{0}; + int32_t stateValue_{0}; + uint32_t startBs{0}; + uint32_t endBs{0}; + uint32_t processNum{0}; + uint32_t resNum{0}; + uint32_t resLen{0}; + uint32_t offsetIndex{0}; + uint32_t maxLocalBs{0}; + LocalTensor offsetReduceLocal_; + LocalTensor countReduceLocal_; +}; + +template +__aicore__ inline void MoeDistributeCombineA2Layered::Init( + GM_ADDR expandX, GM_ADDR expandIdx, GM_ADDR sendCount, GM_ADDR offsetInner, GM_ADDR offsetOuter, GM_ADDR countOuter, + GM_ADDR scales, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const MoeDistributeCombineA2TilingData *tilingData, + __gm__ void *mc2InitTiling, __gm__ void *mc2CcTiling) +{ + tpipe_ = pipe; + expandXGM_ = expandX; + expandIdxGM_ = expandIdx; + sendCountGM_ = sendCount; + scalesGM_ = scales; + XOutGM_ = XOut; + rankId_ = tilingData->moeDistributeCombineInfo.epRankId; + axisBS_ = tilingData->moeDistributeCombineInfo.bs; + globalBs = tilingData->moeDistributeCombineInfo.globalBs; + // if (globalBs >= 256U) { + // maxLocalBs = 256U; + // } else { + // maxLocalBs = globalBs; + // } + axisH_ = tilingData->moeDistributeCombineInfo.h; + axisK_ = tilingData->moeDistributeCombineInfo.k; + aivNum_ = tilingData->moeDistributeCombineInfo.aivNum; + moeExpertNum_ = tilingData->moeDistributeCombineInfo.moeExpertNum; + worldSize_ = tilingData->moeDistributeCombineInfo.epWorldSize; + + auto contextGM = AscendC::GetHcclContext(); + winContext_ = (__gm__ HcclOpResParam *)contextGM; + hccl_.Init(contextGM, mc2InitTiling); + hccl_.SetCcTiling(mc2CcTiling); + + halfWinSize_ = RDMA_DATA_SIZE / 2U; + IPC_DATA_SIZE = winContext_->winSize - RDMA_DATA_SIZE - IPC_DATA_OFFSET; + dataSpaceSize_ = halfWinSize_ - STATE_SPACE_SIZE; + windowInGM_ = hccl_.GetWindowsInAddr(rankId_); + bufferIdGlobal_.SetGlobalBuffer((__gm__ uint32_t *)(windowInGM_ + dataSpaceSize_ + worldSize_ * STATE_OFFSET)); + bufferId_ = bufferIdGlobal_(0); + windowInGM_ = windowInGM_ + halfWinSize_ * bufferId_; + windowOutGM_ = hccl_.GetWindowsOutAddr(rankId_) + halfWinSize_ * bufferId_; + coreIdx_ = GetBlockIdx(); + serverNum = worldSize_ / SERVER_RANK_SIZE; + expandXGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)expandX); + expandIdxGlobal_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx); + sendCountGlobal_.SetGlobalBuffer((__gm__ int32_t *)sendCount); + offsetInnerGlobal_.SetGlobalBuffer((__gm__ int32_t *)offsetInner); + countOuterGlobal_.SetGlobalBuffer((__gm__ int32_t *)countOuter); + offsetOuterGlobal_.SetGlobalBuffer((__gm__ int32_t *)offsetOuter); + bkCountGlobal_.SetGlobalBuffer((__gm__ int32_t *)(sendCount + worldSize_ * localMoeExpertNum_ * 4)); + expandScalesGlobal_.SetGlobalBuffer((__gm__ float *)scales); + expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut); + readStateGlobal_.SetGlobalBuffer((__gm__ int32_t 
*)(windowOutGM_ + dataSpaceSize_)); + workspaceGlobal_.SetGlobalBuffer((__gm__ uint64_t *)(windowOutGM_ + dataSpaceSize_ + BATCH_WRITE_ITEM_OFFSET)); + workspaceGlobal32_.SetGlobalBuffer((__gm__ uint32_t *)(windowOutGM_ + dataSpaceSize_ + BATCH_WRITE_ITEM_OFFSET)); + localMoeExpertNum_ = moeExpertNum_ / worldSize_; + expandXRows_ = localMoeExpertNum_ * axisBS_ * worldSize_; + rankSizeOnWin_ = static_cast(dataSpaceSize_ / worldSize_ / BLOCK_SIZE * BLOCK_SIZE); + statusSpaceGm_ = windowInGM_ + dataSpaceSize_; + statusSpaceGlobal_.SetGlobalBuffer((__gm__ int32_t *)statusSpaceGm_); + dataOffsetOnWin_ = rankId_ * rankSizeOnWin_; + stateOffsetOnWin_ = static_cast(dataSpaceSize_ + rankId_ * STATE_OFFSET); + axisHFloatSize_ = axisH_ * static_cast(sizeof(float)); + axisHExpandXTypeSize_ = axisH_ * static_cast(sizeof(ExpandXType)); + + uint64_t winSizeMin = + moeExpertNum_ * axisBS_ * (axisHExpandXTypeSize_ + EXTRA_TOKEN_INFO_NUM * axisK_ * sizeof(uint32_t)) + + IPC_DATA_OFFSET + RDMA_DATA_SIZE; // 考虑负载极其不均衡时,HCCL BUFFSIZE需要开的大小 + assert(winContext_->winSize >= winSizeMin, + "The HCCL_BUFFSIZE is %lluMB, the min value should be %lluMB. \ + epWorldSize:%u, epRankId:%u, moeExpertNum:%u, globalBs:%u, bs:%u, k:%u, h:%u, aivNum:%u, \ + totalUbSize:%llu, hcclBufferSize:%u\n", + winContext_->winSize / MB_SIZE, winSizeMin / MB_SIZE, tilingData->moeDistributeCombineInfo.epWorldSize, + tilingData->moeDistributeCombineInfo.epRankId, tilingData->moeDistributeCombineInfo.moeExpertNum, + tilingData->moeDistributeCombineInfo.globalBs, tilingData->moeDistributeCombineInfo.bs, + tilingData->moeDistributeCombineInfo.k, tilingData->moeDistributeCombineInfo.h, + tilingData->moeDistributeCombineInfo.aivNum, tilingData->moeDistributeCombineInfo.totalUbSize, + tilingData->moeDistributeCombineInfo.hcclBufferSize); + + GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusSpaceGm_ + SELF_STATE_OFFSET)); + // coreIdx_ < serverNum + int32_t state = selfStatusTensor(coreIdx_ * UB_ALIGN); + if (state == 0) { + sumTarget_ = static_cast(1); + selfStatusTensor(coreIdx_ * UB_ALIGN) = 1; + stateValue_ = 1; + } else { + sumTarget_ = 0; + selfStatusTensor(coreIdx_ * UB_ALIGN) = 0; + stateValue_ = 0; + } + BuffInit(); + SplitCoreCal(); + if (coreIdx_ == 0U) { + readStateGlobal_.SetValue(0, stateValue_); + DataCacheCleanAndInvalid( + readStateGlobal_); + } + send_counts_inner_offset = static_cast(worldSize_ * localMoeExpertNum_); + offset_inner_offset = send_counts_inner_offset + static_cast(globalBs * serverNum); + send_counts_outer_offset = offset_inner_offset + static_cast(globalBs * axisK_ * serverNum); + offset_outer_offset = send_counts_outer_offset + static_cast(axisBS_); + share_offset = offset_outer_offset + static_cast(axisBS_ * serverNum); + + shareAddrGM_ = sendCount + share_offset; + offsetInnerGM_ = sendCount + offset_inner_offset; + countInnerGM_ = sendCount + send_counts_inner_offset; + offsetOuterGM_ = sendCount + offset_outer_offset; + countOuterGM_ = sendCount + send_counts_outer_offset; + + shareAddrGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCount) + share_offset); + // offsetInnerGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCount) + offset_inner_offset); + // countInnerGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCount) + send_counts_inner_offset); + // offsetOuterGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCount) + offset_outer_offset); + // countOuterGlobal_.SetGlobalBuffer(reinterpret_cast<__gm__ 
int32_t *>(sendCount) + send_counts_outer_offset); + + // LocalTensor sendCountLocal = sendCountBuf_.Get(); + // DataCopy(sendCountLocal, shareAddrGlobal_, RoundUp(SERVER_RANK_SIZE * 2, B32_PER_BLOCK)); // 16 + PipeBarrier(); + for (int i = 0; i < 8; i++) { + shareAddreRank[i] = reinterpret_cast( + RDMA_DATA_SIZE + hccl_.GetWindowsInAddr(rankId_ / SERVER_RANK_SIZE * SERVER_RANK_SIZE + i)); + } +} + +template +__aicore__ inline void MoeDistributeCombineA2Layered::BuffInit() +{ + // tpipe_->InitBuffer(scaleBuf_, 4 * maxLocalBs * sizeof(float)); // 4k + tpipe_->InitBuffer(moeQueue_, BUFFER_NUM, (axisHExpandXTypeSize_ + 32U)); // 7168 * 2 * 2 = 28672 + tpipe_->InitBuffer(statusBuf_, worldSize_ * UB_ALIGN); + tpipe_->InitBuffer(rowTmpFloatBuf_, axisHFloatSize_); + tpipe_->InitBuffer(mulBuf_, axisHFloatSize_); // // 7168 * 4 = 28672 + tpipe_->InitBuffer(sumFloatBuf_, axisHFloatSize_); // // 7168 * 4 = 28672 + tpipe_->InitBuffer(sendCountBuf_, RoundUp(moeExpertNum_ * worldSize_, B32_PER_BLOCK) * + sizeof(int32_t)); // 全局sendCount,为每个专家从不同rank接收的token个数 + tpipe_->InitBuffer(moeSumQueue_, BUFFER_NUM, (axisHExpandXTypeSize_ + 32U)); + tpipe_->InitBuffer(statusSumOutBuf_, sizeof(float)); + tpipe_->InitBuffer(batchWriteItemBuf_, BATCH_WRITE_ITEM_SIZE * worldSize_); + // tpipe_->InitBuffer(offsetReduceBuf_, RoundUp(maxLocalBs * axisK_ * 4, (uint32_t)UB_ALIGN)); // 8k + // tpipe_->InitBuffer(countReduceBuf_, (maxLocalBs + 8) * 4); // 1k + batchWriteItemLocalB64 = batchWriteItemBuf_.Get(); + batchWriteItemLocalB32 = batchWriteItemLocalB64.template ReinterpretCast(); +} +template +__aicore__ inline void MoeDistributeCombineA2Layered::SplitCoreCal() +{ + // 对worldSize按卡分核,得到每个核上处理的卡的数量 + sendRankNum_ = worldSize_ / aivNum_; + uint32_t remainderRankNum = worldSize_ % aivNum_; + startRankId_ = sendRankNum_ * coreIdx_; + if (coreIdx_ < remainderRankNum) { + sendRankNum_++; + startRankId_ += coreIdx_; + } else { + startRankId_ += remainderRankNum; + } + endRankId_ = startRankId_ + sendRankNum_; +} +template +__aicore__ inline void MoeDistributeCombineA2Layered::AlltoAllDispatch() +{ + PRINTF("enter AlltoAllDispatch \n"); + rowTmpFloatLocal_ = rowTmpFloatBuf_.Get(); + ipcSliceSize = IPC_DATA_SIZE / worldSize_; + ipcSliceNodeSize = ipcSliceSize * SERVER_RANK_SIZE; + LocalTensor sendCountLocal = sendCountBuf_.Get(); + // expandScalesLocal_ = scaleBuf_.Get(); + DataCopy(sendCountLocal, sendCountGlobal_, RoundUp(moeExpertNum_ * worldSize_, B32_PER_BLOCK)); + SyncFunc(); + AscendC::DumpTensor(sendCountLocal, 368, 32); + for (uint32_t dstRankId = startRankId_; dstRankId < endRankId_; ++dstRankId) { + // dstRankId 在本机上的同号卡 + uint32_t targetRank = dstRankId % SERVER_RANK_SIZE; + // 计算要发往的目标IPC的地址,不考虑flag偏移 + uint64_t targetRankShareAddr = shareAddreRank[targetRank]; + uint64_t targetRankAddr = + targetRankShareAddr + + static_cast(dstRankId / SERVER_RANK_SIZE * ipcSliceNodeSize + IPC_DATA_OFFSET); + + dstshareMemGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)(targetRankAddr)); + shareFlagGlobal_.SetGlobalBuffer((__gm__ int64_t *)targetRankShareAddr); + // 计算要发送的token数量 + uint32_t rankTokenNum = 0U; + uint32_t serverStartExpId = rankId_ / SERVER_RANK_SIZE * SERVER_RANK_SIZE * localMoeExpertNum_; + for (uint32_t expertId = 0U; expertId < localMoeExpertNum_; ++expertId) { + uint32_t preCount = 0U; + if (expertId != 0U || dstRankId != 0U) { + for (int i = 0; i <= expertId; i++) { + for (int j = 0; j < worldSize_; j++) { + if ((i == expertId) && j >= dstRankId) { + break; + } + // 对epSendCount的前startExpId到expertId - 
1的行求和,和第expertId行的前dstRankId - 1求和 + preCount += sendCountLocal.GetValue((i + rankId_ * localMoeExpertNum_) * worldSize_ + j); + } + } + // preCount = static_cast(sendCountLocal.GetValue(expertId * worldSize_ + dstRankId - 1)); // + // expertId专家从dstRankId收到的token在output上的偏移 + } + + uint32_t tokenNum = + sendCountLocal.GetValue((expertId + rankId_ * localMoeExpertNum_) * worldSize_ + dstRankId); + uint32_t startTokenAddr = preCount * axisH_; + PRINTF( + "[AlltoAllDispatch] rank:%d, coreIdx_:%d, expertId:%d, dstRankId:%d, targetRank:%d, tokenNum:%d, " + "preCount:%d\n", + rankId_, coreIdx_, expertId, dstRankId, targetRank, tokenNum, preCount); + // DataCopy(expandScalesLocal_, expandScalesGlobal_[preCount], (tokenNum + 31) / 32 * 32); + SyncFunc(); + uint32_t tokenOffset = 0; + // 从本server的专家起始id到当前expertId + for (int i = serverStartExpId; i < rankId_ * localMoeExpertNum_ + expertId; i++) { + tokenOffset += sendCountLocal.GetValue(i * worldSize_ + dstRankId); + } + for (uint32_t tokenId = 0U; tokenId < tokenNum; ++tokenId) { + float scaleVal = expandScalesGlobal_[preCount].GetValue(tokenId); + LocalTensor InUb = moeQueue_.AllocTensor(); + LocalTensor InUbTemp = InUb[axisH_].template ReinterpretCast(); + InUbTemp(0) = scaleVal; + SyncFunc(); + DataCopy(InUb, expandXGlobal_[startTokenAddr], axisH_); + moeQueue_.EnQue(InUb); + LocalTensor OutUb = moeQueue_.DeQue(); + DataCopy(dstshareMemGlobal_[(tokenOffset + tokenId) * (axisH_ + 16U)], OutUb, axisH_ + 16U); + moeQueue_.FreeTensor(OutUb); + startTokenAddr += axisH_; + rankTokenNum++; + PipeBarrier(); + } + } + PipeBarrier(); + LocalTensor InUb = statusBuf_.AllocTensor(); + InUb.SetValue(0, 12345); + uint32_t flagOffset = rankId_ % SERVER_RANK_SIZE + dstRankId / SERVER_RANK_SIZE * SERVER_RANK_SIZE; + DataCopy(shareFlagGlobal_[flagOffset * 4], InUb, 4); // *4是因为单次拷贝256byte = 4*int64 + statusBuf_.FreeTensor(InUb); + } + SyncAll(); +} + +template +__aicore__ inline void MoeDistributeCombineA2Layered::SumToWindow() +{ + // 当前假设一个core处理一个rank的数据累加,因为已经只剩下同号卡,所以只有serverNum个rank + if (coreIdx_ < serverNum) { + countReL = 0; + shareFlagGlobal_.SetGlobalBuffer((__gm__ int64_t *)shareAddreRank[rankId_ % SERVER_RANK_SIZE]); + LocalTensor InUb = statusBuf_.AllocTensor(); + for (uint32_t i = 0U; i < SERVER_RANK_SIZE; i++) { + uint32_t waitFlagAddr = coreIdx_ * SERVER_RANK_SIZE + i; + while (true) { + DataCopy(InUb, shareFlagGlobal_[waitFlagAddr * 4], 4); + PipeBarrier(); + if (InUb.GetValue(0) == 12345) { + break; + } + } + } + InUb.SetValue(0, 0); + PipeBarrier(); + for (uint32_t i = 0U; i < SERVER_RANK_SIZE; i++) { + DataCopy(shareFlagGlobal_[(coreIdx_ * SERVER_RANK_SIZE + i) * 4], InUb, + 4); // *4是因为单次拷贝256byte = 4*int64 + PipeBarrier(); + } + + statusBuf_.FreeTensor(InUb); + // LocalTensor offsetReduceLocal = offsetReduceBuf_.Get(); + + int32_t targetRankId = coreIdx_ * SERVER_RANK_SIZE + rankId_ % SERVER_RANK_SIZE; + GlobalTensor offsetReduceGt = offsetInnerGlobal_[MAX_BS * moeExpertNum_ * targetRankId]; + // DataCopy(offsetReduceLocal, + // offsetInnerGlobal_[MAX_BS * moeExpertNum_ * rankId_ + (MAX_BS * localMoeExpertNum_ * + // SERVER_RANK_SIZE * coreIdx_)], RoundUp(MAX_BS * localMoeExpertNum_ * SERVER_RANK_SIZE, + // (uint32_t)(UB_ALIGN / sizeof(int32_t)))); + SyncFunc(); + AscendC::DumpTensor(offsetReduceGt, 452, 128); + + uint64_t copyAddr = shareAddreRank[rankId_ % SERVER_RANK_SIZE] + + static_cast(IPC_DATA_OFFSET + coreIdx_ * ipcSliceNodeSize); + shareMemGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)copyAddr); + uint64_t rdmaAddr = 
(uint64_t)(hccl_.GetWindowsOutAddr(rankId_) + halfWinSize_ * bufferId_ + + coreIdx_ * rankSizeOnWin_ * SERVER_RANK_SIZE); + localOutWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rdmaAddr); + sumFloatLocal_ = sumFloatBuf_.Get(); + offsetIndex = 0U; + for (uint32_t i = 0U; i < MAX_BS; i++) { + bool isTokenInServer = false; + Duplicate(sumFloatLocal_, 0.0f, axisH_); + for (uint32_t j = 0U; j < static_cast(localMoeExpertNum_ * SERVER_RANK_SIZE); j++) { + int32_t expId = j + rankId_ / SERVER_RANK_SIZE * SERVER_RANK_SIZE * localMoeExpertNum_; + int32_t offsetValue = offsetReduceGt.GetValue(i * moeExpertNum_ + expId); + if (offsetValue == -1) continue; + isTokenInServer = true; + tmpUb_ = moeSumQueue_.AllocTensor(); + uint32_t offsetOnIpc = (offsetValue * (axisH_ + 16U) * sizeof(ExpandXType)) / sizeof(ExpandXType); + // uint32_t offsetOnIpc = (offsetValue / (MAX_BS) * ipcSliceSize + + // offsetValue % (MAX_BS) * (axisH_ + 16U) * + // sizeof(ExpandXType)) / sizeof(ExpandXType); + DataCopy(tmpUb_, shareMemGlobal_[offsetOnIpc], axisH_ + 16U); + SyncFunc(); + LocalTensor InUbTemp = tmpUb_[axisH_].template ReinterpretCast(); + float scaleVal = InUbTemp(0); + SyncFunc(); + moeSumQueue_.EnQue(tmpUb_); + LocalTensor tmpOtherUb_ = moeSumQueue_.DeQue(); + Cast(rowTmpFloatLocal_, tmpOtherUb_, AscendC::RoundMode::CAST_NONE, axisH_); + PipeBarrier(); + AscendC::Muls(rowTmpFloatLocal_, rowTmpFloatLocal_, scaleVal, axisH_); + PipeBarrier(); + AscendC::Add(sumFloatLocal_, sumFloatLocal_, rowTmpFloatLocal_, axisH_); + moeSumQueue_.FreeTensor(tmpOtherUb_); + offsetIndex++; + PipeBarrier(); + + PRINTF( + "[SumToWindow] rank:%d, coreId:%d, i:%d, j:%d, offsetValue:%d, globalBs:%d, axisK_:%d, " + "ipcSliceSize:%d, offsetOnIpc:%d, scaleVal:%f, offsetIndex:%d\n", + rankId_, coreIdx_, i, j, offsetValue, globalBs, axisK_, ipcSliceSize, offsetOnIpc, scaleVal, + offsetIndex); + } + PipeBarrier(); + if (!isTokenInServer) { + continue; + } + LocalTensor castUbIn = mulBuf_.Get(); + SyncFunc(); + Cast(castUbIn, sumFloatLocal_, AscendC::RoundMode::CAST_RINT, axisH_); + SyncFunc(); + DataCopy(localOutWindow_[countReL * axisH_], castUbIn, axisH_); + countReL++; + PipeBarrier(); + } + } + + SyncAll(); +} + +template +__aicore__ inline void MoeDistributeCombineA2Layered::AlltoAllServerDispatch() +{ + uint64_t selfTotalNum = 0U; + if (coreIdx_ < serverNum) { + uint32_t tragRankId = rankId_ % SERVER_RANK_SIZE + coreIdx_ * SERVER_RANK_SIZE; + // 目标卡 GetWindowsOutAddr 地址 + uint64_t dstrdmaAddr = (uint64_t)(hccl_.GetWindowsInAddr(tragRankId) + halfWinSize_ * bufferId_ + + (rankId_ / SERVER_RANK_SIZE) * rankSizeOnWin_ * SERVER_RANK_SIZE); + uint64_t srcrdmaAddr = (uint64_t)(hccl_.GetWindowsOutAddr(rankId_) + halfWinSize_ * bufferId_ + + coreIdx_ * rankSizeOnWin_ * SERVER_RANK_SIZE); + + // countReL + batchWriteItemLocalB64(0) = srcrdmaAddr; + batchWriteItemLocalB64(0 + 1) = dstrdmaAddr; + if (coreIdx_ == (rankId_ / SERVER_RANK_SIZE)) { + batchWriteItemLocalB64(0 + 2) = 0; + } else { + batchWriteItemLocalB64(0 + 2) = countReL * axisH_; + } + batchWriteItemLocalB32(0 + 6) = HcclDataType::HCCL_DATA_TYPE_FP16; + batchWriteItemLocalB32(0 + 7) = tragRankId; + + SyncFunc(); + DataCopy(workspaceGlobal_[coreIdx_ * 4], batchWriteItemLocalB64, 4); + } + SyncAll(); + if (coreIdx_ == 0U) { + HcclHandle handleId = hccl_.BatchWrite((GM_ADDR)(workspaceGlobal_.GetPhyAddr()), serverNum); + bufferIdGlobal_(0) = bufferId_ ^ 1; + } + if (coreIdx_ == (rankId_ / SERVER_RANK_SIZE)) { + uint64_t srcrdmaAddr = (uint64_t)(hccl_.GetWindowsOutAddr(rankId_) + + 
halfWinSize_ * bufferId_ + //(rankId_ % SERVER_RANK_SIZE) * rankSizeOnWin_); + (rankId_ / SERVER_RANK_SIZE) * rankSizeOnWin_ * SERVER_RANK_SIZE); + uint64_t dstrdmaAddr = (uint64_t)(hccl_.GetWindowsInAddr(rankId_) + + halfWinSize_ * bufferId_ + //(rankId_ % SERVER_RANK_SIZE) * rankSizeOnWin_); + (rankId_ / SERVER_RANK_SIZE) * rankSizeOnWin_ * SERVER_RANK_SIZE); + + localInWindow_.SetGlobalBuffer((__gm__ ExpandXType *)(dstrdmaAddr)); + localOutWindow_.SetGlobalBuffer((__gm__ ExpandXType *)(srcrdmaAddr)); + + for (uint32_t tokenId = 0U; tokenId < countReL; ++tokenId) { + LocalTensor InUb = moeQueue_.AllocTensor(); + DataCopy(InUb, localOutWindow_[tokenId * axisH_], axisH_); + moeQueue_.EnQue(InUb); + LocalTensor OutUb = moeQueue_.DeQue(); + DataCopy(localInWindow_[tokenId * axisH_], OutUb, axisH_); + moeQueue_.FreeTensor(OutUb); + } + } +} + +template +__aicore__ inline void MoeDistributeCombineA2Layered::SetStatus() +{ + if (coreIdx_ != 0U) { + SyncAll(); + return; + } + + uint32_t selfServerID = rankId_ / SERVER_RANK_SIZE; + for (uint32_t serverID = 0U; serverID < serverNum; serverID++) { + uint32_t targetRank = rankId_ % SERVER_RANK_SIZE + serverID * SERVER_RANK_SIZE; + batchWriteItemLocalB64(serverID * 4) = (uint64_t)(readStateGlobal_.GetPhyAddr()); + batchWriteItemLocalB64(serverID * 4 + 1) = + (uint64_t)(hccl_.GetWindowsInAddr(targetRank) + halfWinSize_ * bufferId_ + dataSpaceSize_ + + selfServerID * STATE_OFFSET); + batchWriteItemLocalB64(serverID * 4 + 2) = 8; + batchWriteItemLocalB32(serverID * 8 + 6) = HcclDataType::HCCL_DATA_TYPE_INT32; + batchWriteItemLocalB32(serverID * 8 + 7) = targetRank; + } + SyncFunc(); + DataCopy(workspaceGlobal_[serverNum * 4], batchWriteItemLocalB64, 4 * (serverNum)); + GlobalTensor localStateGlobal; + localStateGlobal.SetGlobalBuffer((__gm__ int32_t *)(windowInGM_ + dataSpaceSize_ + selfServerID * STATE_OFFSET)); + localStateGlobal.SetValue(0, stateValue_); + DataCacheCleanAndInvalid( + localStateGlobal); + SyncFunc(); + if ASCEND_IS_AIV { + HcclHandle handleId = + hccl_.BatchWrite((GM_ADDR)(workspaceGlobal_[serverNum * 4].GetPhyAddr()), serverNum); + } + SyncAll(); +} + +template +__aicore__ inline void MoeDistributeCombineA2Layered::WaitDispatch() +{ + if (coreIdx_ < serverNum) { + uint32_t targetRank = rankId_ % SERVER_RANK_SIZE + (coreIdx_)*SERVER_RANK_SIZE; + LocalTensor statusTensor = statusBuf_.Get(); + uint32_t readNum = 1U; + DataCopyParams intriParams{static_cast(readNum), 1, 15, 0}; // srcStride为15个block + while (true) { + DataCopy(statusTensor, statusSpaceGlobal_[(coreIdx_)*STATE_OFFSET / sizeof(int32_t)], intriParams); + PipeBarrier(); + int32_t sumOfFlag = statusTensor.GetValue(0); + + if (sumOfFlag == sumTarget_) { + break; + } + } + } + + SyncAll(); +} + +template +__aicore__ inline void MoeDistributeCombineA2Layered::Preload() +{ + if (coreIdx_ >= 8U) { + return; + } + processNum = MAX_BS / 8U; + resNum = MAX_BS - processNum * 8U; + resLen = (resNum == 0U) ? 
0U : 1U; + startBs = 0U; + endBs = 0U; + if (coreIdx_ < resNum) { + processNum += 1U; + startBs = coreIdx_ * processNum; + endBs = startBs + processNum; + } else { + startBs = coreIdx_ * processNum + resNum; + endBs = startBs + processNum; + } + uint64_t selfRankAddr = (uint64_t)(hccl_.GetWindowsInAddr(rankId_) + halfWinSize_ * bufferId_); + localInWindow_.SetGlobalBuffer((__gm__ ExpandXType *)(selfRankAddr)); + // offsetReduceLocal_ = offsetReduceBuf_.Get(); + // countReduceLocal_ = countReduceBuf_.Get(); + // DataCopy( + // offsetReduceLocal_, offsetOuterGlobal_, RoundUp(axisBS_ * serverNum, (uint32_t)(UB_ALIGN / + // sizeof(int32_t)))); + // DataCopy(countReduceLocal_, countOuterGlobal_, RoundUp(axisBS_, (uint32_t)(UB_ALIGN / sizeof(int32_t)))); + SyncFunc(); + offsetIndex = 0U; + sumFloatLocal_ = sumFloatBuf_.Get(); + + if (startBs != 0U) { + offsetIndex = countOuterGlobal_.GetValue(startBs - 1U); + } +} +template +__aicore__ inline void MoeDistributeCombineA2Layered::SumToServer() +{ + if (coreIdx_ >= 8U) { + SyncAll(); + return; + } + uint32_t count = startBs; + for (uint32_t i = startBs; i < endBs; i++) { + // int offsetPre = 0; + // int offsetCur = countOuterGlobal_.GetValue(i); + // if (i != 0U) { + // offsetPre = countOuterGlobal_.GetValue(i - 1); + // } + // int copyNum = offsetCur - offsetPre; + // if (!copyNum) { + // break; + // } + int flag = 0; + Duplicate(sumFloatLocal_, 0.0f, axisH_); + for (int j = 0; j < serverNum; j++) { + int cntOuter = offsetOuterGlobal_.GetValue(i * serverNum + j); + if (cntOuter == -1) { + continue; + } + tmpUb_ = moeSumQueue_.AllocTensor(); + flag = 1; + int offsetOnIpc = (cntOuter * axisH_ * sizeof(ExpandXType)) / sizeof(ExpandXType); + uint64_t selfRankAddr = (uint64_t)(hccl_.GetWindowsInAddr(rankId_) + halfWinSize_ * bufferId_ + + j * rankSizeOnWin_ * SERVER_RANK_SIZE); + localInWindow_.SetGlobalBuffer((__gm__ ExpandXType *)(selfRankAddr)); + DataCopy(tmpUb_, localInWindow_[offsetOnIpc], axisH_); + moeSumQueue_.EnQue(tmpUb_); + LocalTensor tmpOtherUb_ = moeSumQueue_.DeQue(); + // cast before muls + Cast(rowTmpFloatLocal_, tmpOtherUb_, AscendC::RoundMode::CAST_NONE, axisH_); + PipeBarrier(); + // add mulBufLocal to sumFloatBufLocal + AscendC::Add(sumFloatLocal_, sumFloatLocal_, rowTmpFloatLocal_, axisH_); + moeSumQueue_.FreeTensor(tmpOtherUb_); + } + PipeBarrier(); + if (!flag) { + continue; + } + LocalTensor castUbIn = mulBuf_.Get(); + SyncFunc(); + Cast(castUbIn, sumFloatLocal_, AscendC::RoundMode::CAST_RINT, axisH_); + SyncFunc(); + DataCopy(expandOutGlobal_[count * axisH_], castUbIn, axisH_); + count++; + PipeBarrier(); + } + + SyncAll(); +} + +template +__aicore__ inline void MoeDistributeCombineA2Layered::Process() +{ + printf("step0\n"); + if ASCEND_IS_AIV { + printf("step1\n"); + AlltoAllDispatch(); // 所有核执行 + printf("step2\n"); + SumToWindow(); // 前serverNum个核执行 + printf("step3\n"); + AlltoAllServerDispatch(); // 前serverNum个核执行 + printf("step4\n"); + SetStatus(); // 0核执行 + printf("step5\n"); + Preload(); // 前8核执行 + printf("step6\n"); + WaitDispatch(); // 前serverNum个核执行 + printf("step7\n"); + SumToServer(); // 前8核执行 + printf("step8\n"); + hccl_.Finalize(); + printf("step9\n"); + } +} + +} // namespace MoeDistributeCombineA2Impl +#endif // MOE_DISTRIBUTE_COMBINE_A2_LAYERED_H diff --git a/csrc/deepep/ops2/op_kernel/moe_distribute_combine_a2_tiling.h b/csrc/deepep/ops2/op_kernel/moe_distribute_combine_a2_tiling.h new file mode 100644 index 00000000..4b00995c --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/moe_distribute_combine_a2_tiling.h 
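For reference, the layered combine defined in moe_distribute_combine_a2_layered.h above runs its stages in a fixed order, and the comments in its Process() record which cores do real work in each stage (all cores, the first serverNum cores, core 0, or the first 8 cores). A minimal outline of that pipeline with hypothetical stub functions, illustrative only:

#include <cstdint>

// Hypothetical stubs standing in for the kernel's member functions.
inline void AlltoAllDispatch() {}
inline void SumToWindow() {}
inline void AlltoAllServerDispatch() {}
inline void SetStatus() {}
inline void Preload() {}
inline void WaitDispatch() {}
inline void SumToServer() {}

// Mirrors MoeDistributeCombineA2Layered::Process(); each stage internally limits which cores
// actually do work, as noted in the comments.
inline void LayeredCombinePipeline()
{
    AlltoAllDispatch();        // all cores: stage tokens into the intra-server IPC windows
    SumToWindow();             // first serverNum cores: weighted intra-server reduction into the RDMA out window
    AlltoAllServerDispatch();  // first serverNum cores prepare items; core 0 issues the cross-server BatchWrite
    SetStatus();               // core 0: publish arrival flags to the peer servers
    Preload();                 // first 8 cores: precompute per-token output offsets
    WaitDispatch();            // first serverNum cores: spin until the remote flags arrive
    SumToServer();             // first 8 cores: accumulate the per-server partial sums into XOut
}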
@@ -0,0 +1,31 @@ +#ifndef MOE_DISTRIBUTE_CMOBINE_A2_TILING_H +#define MOE_DISTRIBUTE_CMOBINE_A2_TILING_H + +#include +#include "kernel_tiling/kernel_tiling.h" + +struct MoeDistributeCombineA2Info { + uint32_t epWorldSize; // epWorldSize + uint32_t tpWorldSize; // tpWorldSize + uint32_t epRankId; // epRankId + uint32_t tpRankId; // tpRankId + uint32_t expertSharedType; // expert type + uint32_t sharedExpertRankNum; // shared expert number + uint32_t moeExpertNum; // moe expert number + uint32_t globalBs; // globalBs = BS * worldSize + uint32_t bs; // bs + uint32_t k; // k + uint32_t h; // h + uint32_t aivNum; // aivNum + uint64_t totalUbSize; // epWorldSize + uint32_t hcclBufferSize; // HCCL windows, unit:B + uint32_t rsd; +}; + +struct MoeDistributeCombineA2TilingData { + Mc2InitTiling mc2InitTiling; + Mc2CcTiling mc2CcTiling; + MoeDistributeCombineA2Info moeDistributeCombineInfo; +}; + +#endif //__MOE_DISTRIBUTE_CMOBINE_A2_TILING_H__ diff --git a/csrc/deepep/ops2/op_kernel/notify_dispatch_a2.cpp b/csrc/deepep/ops2/op_kernel/notify_dispatch_a2.cpp new file mode 100644 index 00000000..064d4de9 --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/notify_dispatch_a2.cpp @@ -0,0 +1,71 @@ +#include "kernel_operator.h" +#include "notify_dispatch_a2.h" +#include "notify_dispatch_tiling_a2.h" + +#define TILING_KEY_FLOAT16 20 +#define TILING_KEY_BFLOAT16 21 +#define TILING_KEY_FLOAT 22 +#define TILING_KEY_INT 23 +#define TILING_KEY_A2_FLOAT16 120 +#define TILING_KEY_A2_BFLOAT16 121 +#define TILING_KEY_A2_FLOAT 122 +#define TILING_KEY_A2_INT 123 + +#define KERNEL_USE_WORKSPACE (1 * 1024 * 1024) + +extern "C" __global__ __aicore__ void notify_dispatch_a2(GM_ADDR sendData, GM_ADDR tokenPerExpertData, GM_ADDR tmpData, + GM_ADDR sendDataOffset, GM_ADDR recvData, + GM_ADDR tokenServerIdx, GM_ADDR tokensUniquePerServer, + GM_ADDR epRankTokenCnt, GM_ADDR localEpTokenCnt, + GM_ADDR srcOffsetRankTokenIdx, GM_ADDR dstOffsetRankTokenIdx, + GM_ADDR offsetInner, GM_ADDR countOuter, GM_ADDR expandIdx, + GM_ADDR workspace, GM_ADDR tiling) +{ + REGISTER_TILING_DEFAULT(NotifyDispatchA2TilingData); + GET_TILING_DATA_WITH_STRUCT(NotifyDispatchA2TilingData, tilingData, tiling); + + // hcomm will set magic later in init + uint32_t magic = 1; + GM_ADDR commArgs = nullptr; + + int localRank = tilingData.notifyDispatchInfoA2.localRankId; + int localRankSize = tilingData.notifyDispatchInfoA2.localRankSize; + int rank = tilingData.notifyDispatchInfoA2.rankId; + int rankSize = tilingData.notifyDispatchInfoA2.rankSize; + int64_t len = tilingData.notifyDispatchInfoA2.sendCount; + int64_t numTokens = tilingData.notifyDispatchInfoA2.numTokens; + int64_t topkNum = tilingData.notifyDispatchInfoA2.topkNum; + int64_t numExperts = tilingData.notifyDispatchInfoA2.numExperts; + + GM_ADDR sendDataInput = sendData; + GM_ADDR tokenPerExpertDataInput = tokenPerExpertData; + GM_ADDR tmpDataInput = tmpData; + + GM_ADDR sendDataOffsetOutput = sendDataOffset; + GM_ADDR recvDataOutput = recvData; + GM_ADDR tokenServerIdxOutput = tokenServerIdx; + GM_ADDR tokensUniquePerServerOutput = tokensUniquePerServer; + GM_ADDR epRankTokenCntOutput = epRankTokenCnt; + GM_ADDR localEpTokenCntOutput = localEpTokenCnt; + GM_ADDR srcOffsetRankTokenIdxOutput = srcOffsetRankTokenIdx; + GM_ADDR dstOffsetRankTokenIdxOutput = dstOffsetRankTokenIdx; + GM_ADDR offsetInnerOutput = offsetInner; + GM_ADDR countOuterOutput = countOuter; + GM_ADDR expandIdxOutput = expandIdx; + + // fill in unused args + uint32_t extraFlag = 0; + GM_ADDR scale = nullptr; + int root = 0; + int op 
= 0; + int cycleCount = 0; + int64_t scaleCount = 0; + GM_ADDR offset = nullptr; + int blockNum = GetBlockNum(); + + if (TILING_KEY_IS(TILING_KEY_A2_INT)) { + NotifyDispatchA2 opKernel(rank, rankSize, extraFlag); + opKernel.Init(KERNELS_ARGS_CALL_A2_ALL2ALL()); + opKernel.Process(); + } +} diff --git a/csrc/deepep/ops2/op_kernel/notify_dispatch_a2.h b/csrc/deepep/ops2/op_kernel/notify_dispatch_a2.h new file mode 100644 index 00000000..01fadb60 --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/notify_dispatch_a2.h @@ -0,0 +1,1292 @@ +#ifndef NOTIFY_DISPATCH_A2_H +#define NOTIFY_DISPATCH_A2_H + +#include +#include "kernel_operator.h" + +#include "comm_args.h" +#include "data_copy.h" +#include "sync_collectives.h" +#include "moe_distribute_base.h" +#include "notify_dispatch_tiling_a2.h" + +using namespace AscendC; +using namespace Moe; + +#define KERNELS_ARGS_FUN_A2_ALL2ALL() \ + GM_ADDR sendDataInput, GM_ADDR tokenPerExpertDataInput, GM_ADDR tmpDataInput, GM_ADDR sendDataOffsetOutput, \ + GM_ADDR recvDataOutput, int64_t len, int64_t numTokens, int64_t topkNum, int64_t numExperts, int op, int root, \ + int cycleCount, GM_ADDR scale, int64_t scaleCount, GM_ADDR offset, int localRank, int localRankSize, \ + GM_ADDR commArgs, GM_ADDR tokenServerIdxOutput, GM_ADDR tokensUniquePerServerOutput, \ + GM_ADDR epRankTokenCntOutput, GM_ADDR localEpTokenCntOutput, GM_ADDR srcOffsetRankTokenIdxOutput, \ + GM_ADDR dstOffsetRankTokenIdxOutput, GM_ADDR offsetInnerOutput, GM_ADDR countOuterOutput, \ + GM_ADDR expandIdxOutput, GM_ADDR workspace, GM_ADDR tiling + +#define KERNELS_ARGS_CALL_A2_ALL2ALL() \ + sendDataInput, tokenPerExpertDataInput, tmpDataInput, sendDataOffsetOutput, recvDataOutput, len, numTokens, \ + topkNum, numExperts, op, root, cycleCount, scale, scaleCount, offset, localRank, localRankSize, commArgs, \ + tokenServerIdxOutput, tokensUniquePerServerOutput, epRankTokenCntOutput, localEpTokenCntOutput, \ + srcOffsetRankTokenIdxOutput, dstOffsetRankTokenIdxOutput, offsetInnerOutput, countOuterOutput, \ + expandIdxOutput, workspace, tiling + +#define printflag(ss) \ + if (blockIdx < coreNumBetween) { \ + printf("========rank:%d coreIdx:%d " #ss "\n", rank, blockIdx); \ + } + +template +__aicore__ inline void SyncFunc() +{ + int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); +} + +template +class NotifyDispatchA2 +{ + constexpr static int32_t MAX_CORE_NUM = 20; + constexpr static int64_t MULTI_RANK_SIZE = 4; // 每个core最多往4个rank发送数据,64卡场景 + constexpr static int64_t MAX_RANK_SIZE = 64; // 910B设备本算子最大支持的rank数,64卡场景 + constexpr static int32_t INVALID_RANK = -1; + constexpr static uint32_t TEMP_BUF_LEN = 128 * 1024; // tuf注册长度为128K,剩余部分注册为其他buffer + constexpr static uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; // 对齐tiling + + constexpr static uint32_t BW_ITEM_SIZE = 32; // = sizeof(BatchWriteItem) + constexpr static uint32_t U64_PER_ITEM = BW_ITEM_SIZE / sizeof(uint64_t); // 每个BatchWriteItem占多少个unit64 + constexpr static uint32_t U32_PER_ITEM = BW_ITEM_SIZE / sizeof(uint32_t); // 每个BatchWriteItem占多少个unit32 + constexpr static uint32_t BW_MEB_OFFSET64_LOCAL_GM = 0; // BatchWriteItem成员变量offset,按照sizeof(unit64)计算 + constexpr static uint32_t BW_MEB_OFFSET64_REMOTE_GM = 1; // BatchWriteItem成员变量offset,按照sizeof(unit64)计算 + constexpr static uint32_t BW_MEB_OFFSET64_DATA_SIZE = 2; // BatchWriteItem成员变量offset,按照sizeof(unit64)计算 + constexpr static uint32_t BW_MEB_OFFSET32_DATA_TYPE = 6; // BatchWriteItem成员变量offset,按照sizeof(unit32)计算 + 
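// Together with BW_MEB_OFFSET32_TARGET_RANK below, these BW_MEB_OFFSET* constants describe the
// 32-byte batch-write descriptor assembled in workspace for hccl_.BatchWrite(). A plain-struct
// view of that layout (hypothetical, for illustration only; field meanings follow how the slots
// are filled elsewhere in this patch):
//
//   #include <cstdint>
//   struct BatchWriteItemSketch {    // 32 bytes, i.e. BW_ITEM_SIZE
//       uint64_t localGmAddr;        // u64 slot 0: source address in the local window
//       uint64_t remoteGmAddr;       // u64 slot 1: destination address on the target rank
//       uint64_t dataSize;           // u64 slot 2: transfer length
//       uint32_t dataType;           // u32 slot 6: HcclDataType of the payload
//       uint32_t targetRank;         // u32 slot 7: destination rank id
//   };
//   static_assert(sizeof(BatchWriteItemSketch) == 32, "mirrors BW_ITEM_SIZE");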
constexpr static uint32_t BW_MEB_OFFSET32_TARGET_RANK = 7; // BatchWriteItem成员变量offset,按照sizeof(unit32)计算 + + constexpr static int32_t FLAG_VALUE = 0xFFFFFFFF; + constexpr static uint32_t STATUS_ENTRY_SIZE = 32; // 每个status entry占用的空间大小, bytes + constexpr static uint32_t U32_STATUS_ENTRY = STATUS_ENTRY_SIZE / sizeof(int32_t); + constexpr static uint32_t FLAG_OFFSET = 8; // status_flag 在 statusTensor中的offset, bytes + constexpr static uint32_t SOURCE_RANK_OFFSET = 16; // sourceRankId 在 statusTensor中的offset, bytes + constexpr static uint32_t DEST_RANK_OFFSET = 20; // destRankId 在 statusTensor中的offset, bytes + constexpr static uint32_t DATALEN_OFFSET = 24; // dataLen 在 statusTensor中的offset, bytes + constexpr static uint32_t UB_ALIGN = 32; // UB按32字节对齐 + constexpr static uint32_t EXP_TOKEN_COUNT_FLAG_CNT = UB_ALIGN / sizeof(int32_t); // 8 + constexpr static uint32_t GM_ALIGN = 64; // GM按64字节对齐 + + constexpr static uint32_t MAX_BS = 4096; // 每卡支持的最大bs + +public: + __aicore__ inline NotifyDispatchA2(int rank, int rankSize, uint32_t extraFlag) + : rank(rank), rankSize(rankSize), extraFlag(extraFlag) + {} + + __aicore__ inline void Init(KERNELS_ARGS_FUN_A2_ALL2ALL()) + { + InitAll2AllLayeredRdma(KERNELS_ARGS_CALL_A2_ALL2ALL()); + + tokenPerExpertDataAlignLen = Ceil(numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + sendDataOffsetAlignLen = Ceil(numExperts * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + sendDataAlignLen = Ceil(len * sizeof(T), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; // TODO: 数据长度 + perRankDataNum = len; // 发送所有数据 + + InitTensorLen(); + + InitShare(); + // 初始化核分组, 需要外部调用保证所有的server的localRankSize均相同 + serverNum = CeilDiv(rankSize, localRankSize); + serverId = rank / localRankSize; + // printf("rank:%d coreIdx:%d rankSize:%d localRankSize:%d serverNum:%d serverId:%d\n", rank, blockIdx, + // rankSize, localRankSize, serverNum, serverId); + InitCoreGroup(); + // 初始化目标rank列表 + InitTargetRank(); + // 初始化数据切片 + InitDataSlice(); + + this->sendDataInput = (__gm__ T *)sendDataInput; + this->tokenPerExpertDataInput = (__gm__ int32_t *)tokenPerExpertDataInput; + this->tmpDataInput = (__gm__ int32_t *)tmpDataInput; + this->sendDataOffsetOutput = (__gm__ T *)sendDataOffsetOutput; + this->recvDataOutput = (__gm__ T *)recvDataOutput; + this->epRankTokenCntOutputGM_ = (__gm__ int32_t *)epRankTokenCntOutput; + + sendDataInputGt.SetGlobalBuffer((__gm__ T *)sendDataInput); + tokenPerExpertDataInputGt.SetGlobalBuffer((__gm__ int32_t *)tokenPerExpertDataInput); + tmpDataInputGt.SetGlobalBuffer((__gm__ int32_t *)tmpDataInput); + sendDataOffsetOutputGt.SetGlobalBuffer((__gm__ T *)sendDataOffsetOutput); + recvDataOutputGt.SetGlobalBuffer((__gm__ T *)recvDataOutput); + + tokenServerIdxOutputGT_.SetGlobalBuffer((__gm__ int32_t *)tokenServerIdxOutput); + tokensUniquePerServerOutputGT_.SetGlobalBuffer((__gm__ int32_t *)tokensUniquePerServerOutput); + epRankTokenCntOutputGT_.SetGlobalBuffer((__gm__ int32_t *)epRankTokenCntOutput); + localEpTokenCntOutputGT_.SetGlobalBuffer((__gm__ int32_t *)localEpTokenCntOutput); + srcOffsetRankTokenIdxOutputGT_.SetGlobalBuffer((__gm__ int32_t *)srcOffsetRankTokenIdxOutput); + dstOffsetRankTokenIdxOutputGT_.SetGlobalBuffer((__gm__ int32_t *)dstOffsetRankTokenIdxOutput); + offsetInnerOutputGT_.SetGlobalBuffer((__gm__ int32_t *)offsetInnerOutput); + countOuterOutputGT_.SetGlobalBuffer((__gm__ int32_t *)countOuterOutput); + expandIdxOutputGT_.SetGlobalBuffer((__gm__ int32_t *)expandIdxOutput); + + // 初始化RDMA相关变量 + // dataSpaceGT_ = workspace; // 需要预留大一些空间供存放交换后拆分出来的数据 + 
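// The window addresses set up below implement a ping-pong scheme: `magic` is bumped with an
// atomic add in InitAll2AllLayeredRdma() on every launch, and `magic % PING_PONG_SIZE` selects
// which of the alternating window buffers this launch uses, presumably so a new launch does not
// overwrite data a peer may still be reading from the previous one. A stand-alone sketch of the
// address math (PING_PONG_SIZE and IPC_BUFF_MAX_SIZE come from comm_args.h; the helper name is
// hypothetical):
//
//   #include <cstdint>
//   uint64_t SelectWindowBase(uint64_t windowBase, uint32_t magic,
//                             uint32_t pingPongSize, uint64_t buffMaxSize) {
//       return windowBase + static_cast<uint64_t>(magic % pingPongSize) * buffMaxSize;
//   }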
windowInGM_ = this->shareAddrs[rank]; + windowOutGM_ = hccl_.GetWindowsOutAddr(rank) + (magic % PING_PONG_SIZE) * IPC_BUFF_MAX_SIZE; + batchWriteInfoTensor_.SetGlobalBuffer((__gm__ uint32_t *)(workspace), rankSize * U32_PER_ITEM); + // batchWriteInfoTensor_.SetGlobalBuffer((__gm__ uint32_t*)(epRankTokenCntOutputGM_), rankSize * U32_PER_ITEM); + // // 出参地址临时使用 + windowInstatusTensor_.SetGlobalBuffer((__gm__ int32_t *)(windowInGM_ + IPC_DATA_OFFSET)); + windowInTensor_.SetGlobalBuffer((__gm__ T *)(windowInGM_ + IPC_DATA_OFFSET)); + windowOutstatusTensor_.SetGlobalBuffer((__gm__ int32_t *)(windowOutGM_ + IPC_DATA_OFFSET)); + windowOutTensor_.SetGlobalBuffer((__gm__ T *)(windowOutGM_ + IPC_DATA_OFFSET)); + + pipe.InitBuffer(batchWriteInfoBuf_, rankSize * BW_ITEM_SIZE); + pipe.InitBuffer(tempBuf_, UB_ALIGN); // 存放临时的立即数 + pipe.InitBuffer(statusBuf_, rankSize * STATUS_ENTRY_SIZE); // rankSize * 32B + statusTensor_ = statusBuf_.Get(); // 保存发送数据量及flag,同时用于计算windows中的偏移 + Duplicate(statusTensor_, 0, rankSize * STATUS_ENTRY_SIZE); + } + + __aicore__ inline void Process() + { + if ASCEND_IS_AIV { + // 第一阶段,处理server间通信 + if (serverNum > 1) { + ProcessBetweenServer(); + } + printflag("beforeProcessWithinServer\n"); + + // 第二阶段,处理server内通信 + ProcessWithinServer(); + SyncAll(); + + printflag("beforeSplitAndCalcData\n"); + // 交换后的数据拆分和计算输出 + SplitAndCalcData(); // TODO: 先验证recv_data + SyncAll(); + + printflag("beforeFinalize\n"); + hccl_.Finalize(); + printflag("AfterFinalize\n"); + } + // if (blockIdx == 0) { + // AscendC::DumpTensor(tokenServerIdxOutputGT_, 166, 16); + // AscendC::DumpTensor(countOuterOutputGT_, 167, 16); + // AscendC::DumpTensor(expandIdxOutputGT_, 168, 16); + // } + PRINTF("[notify] rank:%d, block:%d \n", rank, blockIdx); + } + +private: + FORCE_INLINE_AICORE void InitAll2AllLayeredRdma(KERNELS_ARGS_FUN_A2_ALL2ALL()) + { + this->root = 0; + this->len = len; + this->numExperts = numExperts; + this->numTokens = numTokens; + this->topkNum = topkNum; + this->scale = nullptr; + this->magic = 0; + this->localRank = localRank; + this->localRankSize = localRankSize; + this->xRankSize = localRankSize; + this->yRankSize = rankSize / localRankSize; + this->xRankIdx = rank % localRankSize; + this->yRankIdx = rank / localRankSize; + this->blockIdx = GetBlockIdx(); + this->blockNum = GetBlockNum(); + uint8_t ctxIdx; + + ctxIdx = COMM_EP_IDX; + + // 初始化RDMA相关变量 + auto tilingData = (__gm__ NotifyDispatchA2TilingData *)tiling; + __gm__ void *mc2InitTiling = (__gm__ void *)(&(tilingData->mc2InitTiling)); + __gm__ void *mc2CcTiling = (__gm__ void *)(&(tilingData->mc2CcTiling1)); + + auto contextGM0 = AscendC::GetHcclContext(); + + hccl_.Init(contextGM0, mc2InitTiling); + hccl_.SetCcTiling(mc2CcTiling); + this->winContext_[COMM_EP_IDX] = (__gm__ HcclOpResParam *)contextGM0; + + // 设置并自增magic + magicTensor_.SetGlobalBuffer((__gm__ int32_t *)(hccl_.GetWindowsInAddr(rank) + IPC_DATA_OFFSET - + blockNum * sizeof(int32_t) * EXP_TOKEN_COUNT_FLAG_CNT)); + + pipe.InitBuffer(this->tBuf, TEMP_BUF_LEN); + LocalTensor tempLocal = tBuf.Get(); + tempLocal(0) = 1; + // 使用atomic方式实现+1 + AscendC::SetAtomicAdd(); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); // 等待SetValue完成 + DataCopy(magicTensor_[blockIdx * EXP_TOKEN_COUNT_FLAG_CNT], tempLocal, EXP_TOKEN_COUNT_FLAG_CNT); + AscendC::SetAtomicNone(); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); // 等待DataCopy完成 + magic = magicTensor_.GetValue(blockIdx * EXP_TOKEN_COUNT_FLAG_CNT); + PipeBarrier(); + // 初始化目标rank的shareAddrs + for (int i = 0; 
i < rankSize; i++) { + this->shareAddrs[i] = hccl_.GetWindowsInAddr(i) + (magic % PING_PONG_SIZE) * IPC_BUFF_MAX_SIZE; + } + + sync.Init(this->rank, this->rankSize, this->shareAddrs, tBuf); + } + + template + FORCE_INLINE_AICORE void CpGM2GMPingPong(int64_t dataSizeRemain, const GlobalTensor &sendDataInputGt, + const GlobalTensor &recvDataOutputGT, int op); + template + FORCE_INLINE_AICORE void SetAtomic(int op); + FORCE_INLINE_AICORE void UnsetAtomic(int op); + template + FORCE_INLINE_AICORE void SetWaitEvent(event_t eventId); + + __aicore__ inline void InitTensorLen() + { + numTokensPerExpertAlignLen = Ceil(numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + gNumTokensPerExpertAlignLen = Ceil(rankSize * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + + numTokensUniquePerServerAlignLen = Ceil(serverNum * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + gNumTokensUniquePerServerAlignLen = Ceil(rankSize * serverNum * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + + numTokensPerServerAlignLen = Ceil(MAX_BS * serverNum * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + gNumTokensPerServerAlignLen = + Ceil(rankSize * MAX_BS * serverNum * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + + tokenServerCntAlignLen = Ceil(MAX_BS * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + gTokenServerCntAlignLen = Ceil(rankSize * MAX_BS * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + + tokenServerIdxAlignLen = Ceil(MAX_BS * serverNum * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + gTokenServerIdxAlignLen = Ceil(rankSize * MAX_BS * serverNum * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + + tokenExpertIdxAlignLen = Ceil(MAX_BS * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + gTokenExpertIdxAlignLen = Ceil(rankSize * MAX_BS * numExperts * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + + expertMaxBsSrcOffsetAlignLen = Ceil(numExperts * MAX_BS * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + gExpertMaxBsSrcOffsetAlignLen = + Ceil(rankSize * numExperts * MAX_BS * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + + expertMaxBsOriOffsetAlignLen = Ceil(numExperts * MAX_BS * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + gExpertMaxBsOriOffsetAlignLen = + Ceil(rankSize * numExperts * MAX_BS * sizeof(int32_t), UB_ALIGN_SIZE) * UB_ALIGN_SIZE; + + /* + if (blockIdx == 0) { + PRINTF("[InitTensorLen] rank:%d, blockIdx:%d, send_count:%d, numExperts:%d, rankSize:%d, serverNum:%d, + MAX_BS:%d, \ + numTokensPerExpertAlignLen:%d, gNumTokensPerExpertAlignLen:%d, \ + numTokensUniquePerServerAlignLen:%d, gNumTokensUniquePerServerAlignLen:%d, \ + numTokensPerServerAlignLen:%d, gNumTokensPerServerAlignLen:%d, \ + tokenServerCntAlignLen:%d, gTokenServerCntAlignLen:%d, \ + tokenServerIdxAlignLen:%d, gTokenServerIdxAlignLen:%d, \ + tokenExpertIdxAlignLen:%d, gTokenExpertIdxAlignLen:%d, \ + expertMaxBsSrcOffsetAlignLen:%d, gExpertMaxBsSrcOffsetAlignLen:%d, \ + expertMaxBsOriOffsetAlignLen:%d, gExpertMaxBsOriOffsetAlignLen:%d \n", + rank, blockIdx, len, numExperts, rankSize, serverNum, MAX_BS, + numTokensPerExpertAlignLen, gNumTokensPerExpertAlignLen, + numTokensUniquePerServerAlignLen, gNumTokensUniquePerServerAlignLen, + numTokensPerServerAlignLen, gNumTokensPerServerAlignLen, + tokenServerCntAlignLen, gTokenServerCntAlignLen, + tokenServerIdxAlignLen, gTokenServerIdxAlignLen, + tokenExpertIdxAlignLen, gTokenExpertIdxAlignLen, + expertMaxBsSrcOffsetAlignLen, gExpertMaxBsSrcOffsetAlignLen, + expertMaxBsOriOffsetAlignLen, gExpertMaxBsOriOffsetAlignLen); 
+ } + */ + } + + __aicore__ inline void InitShare() + { + int64_t queNum = MAX_CORE_NUM; + queElemLen = (IPC_BUFF_MAX_SIZE - IPC_DATA_OFFSET) / sizeof(T) / queNum; // 计算共享队列元素大小 + queSize = (queElemLen * sizeof(T) / GM_ALIGN) * GM_ALIGN; // GM 64字节对齐 + queLen = queSize / sizeof(T); // 一个que的可放入的元素数量 + } + + __aicore__ inline void InitCoreGroup() + { + coreNumBetween = (rankSize <= MAX_CORE_NUM) ? rankSize : MAX_CORE_NUM; + coreNumWithin = (rankSize <= MAX_CORE_NUM) ? rankSize : MAX_CORE_NUM; + rankNumPerCore = CeilDiv(rankSize, MAX_CORE_NUM); // 每个核负责的rank数 + } + + // 计算通信目标,分两个阶段: + // 阶段一:处理Server间通信,Server间的同号卡之间进行Pair-wise的通信,顺序为从小到大的循环的环形 + // 阶段二:处理Server内通信,Server内的卡间进行fullmesh通信,同时需要将阶段一的数据传递给其他设备 + __aicore__ inline void InitTargetRank() + { + // 阶段一:server间的target rank, 此处表示数据最终的targetRank,并非直接发送的目标 + int32_t startRankId = blockIdx * rankNumPerCore; + targetRankNum = (rankSize - startRankId) < rankNumPerCore ? (rankSize - startRankId) : rankNumPerCore; + if (targetRankNum < 0) { + targetRankNum = 0; + } + + for (int i = 0; i < targetRankNum; i++) { + targetRank[i] = startRankId + i; + } + // 其余值设置为 invalid + for (int i = targetRankNum; i < MULTI_RANK_SIZE; i++) { + targetRank[i] = INVALID_RANK; + } + } + + __aicore__ inline void InitDataSlice() + { + // 生产者负责搬运本rank的输入数据至共享内存,input-->share + if (blockIdx < coreNumWithin) { + writeGt.SetGlobalBuffer((__gm__ T *)(shareAddrs[rank] + IPC_DATA_OFFSET)); + } + } + + __aicore__ inline void ProcessWithinServer() + { + if (blockIdx < coreNumWithin) { + InputToShareSlice(); + ShareToShareSlice(); + } + } + + __aicore__ inline void InputToShareSlice() + { + if (blockIdx > 0) { + return; + } + // 将本卡在Server内发送的input数据拷贝到本卡的共享内存对应位置 + int targetRankId = rank; + int32_t targetServerId = targetRankId / localRankSize; + + int64_t datalen = this->len; + readGt = sendDataInputGt[0]; + CpGM2GMPingPong(datalen * sizeof(T), readGt, writeGt[queLen * targetRankId + STATUS_ENTRY_SIZE / sizeof(T)], + COPYONLY); // 预留一个flag偏移位置 + // printflag("CpGM2GMPingPong\n"); + + for (int i = 0; i < localRankSize; ++i) { + int32_t curServerRankId = serverId * localRankSize + i; + // printf("SetInner rank:%d coreIdx:%d magic:%d, curServerRankId:%d\n", rank, blockIdx, magic, + // curServerRankId); + sync.SetInnerFlag(magic, 1, curServerRankId, rank); + } + // AscendC::DumpTensor(writeGt[queLen * targetRankId + STATUS_ENTRY_SIZE / sizeof(T)], 338, datalen); + } + + __aicore__ inline void ShareToShareSlice() + { + // 从Server内其他卡的共享内存对应位置拷贝数据到本卡的output + if (blockIdx > 0) { + return; + } + // printflag("ShareToShareSlice\n"); + int64_t recvCount = this->len; + for (int i = 0; i < localRankSize; ++i) { + int32_t targetRankId = serverId * localRankSize + i; + // printf("WaitInner rank:%d coreIdx:%d magic:%d, targetRankId:%d\n", rank, blockIdx, magic, targetRankId); + sync.WaitInnerFlag(magic, 1, rank, targetRankId); + for (int j = 0; j < serverNum; ++j) { + int32_t serverTarRankId = j * localRankSize + i; // 对应为targetRankId的同号卡 + remoteGt.SetGlobalBuffer((__gm__ T *)(shareAddrs[targetRankId] + IPC_DATA_OFFSET + + serverTarRankId * queSize + + STATUS_ENTRY_SIZE)); // 该rank上的第server块 + // PRINTF("[2ShareToShareSlice] rank:%d, blockId:%d, targetRankId:%d, serverTarRankId:%d", rank, + // blockIdx, targetRankId, serverTarRankId); AscendC::DumpTensor(remoteGt, 362, recvCount); + CpGM2GMPingPong(recvCount * sizeof(T), remoteGt, recvDataOutputGt[serverTarRankId * this->len], + COPYONLY); + } + } + } + + __aicore__ inline void AssembleSendData() + { + 
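// What AssembleSendData builds, based on the loop below: for each expert i it packs a
// triple {token count for expert i, running prefix-sum offset, numTokens} into
// sendDataUB (sendPerGroup == 3 entries per expert) and stores the prefix-sum offset
// alone in sendDataOffsetUB. For example, with tokenPerExpertUB = {2, 0, 3} and
// numTokens = N the buffers become
//   sendDataUB       = {2, 0, N,   0, 2, N,   3, 2, N}
//   sendDataOffsetUB = {0, 2, 2}
// Both are then copied back to GM (sendDataInputGt and sendDataOffsetOutputGt).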
pipe.InitBuffer(tokenPerExpertDataBuf, tokenPerExpertDataAlignLen); + pipe.InitBuffer(sendDataBuf, sendDataAlignLen); + pipe.InitBuffer(sendDataOffsetBuf, sendDataOffsetAlignLen); + + __ubuf__ int32_t *tokenPerExpertUB = (__ubuf__ int32_t *)get_imm(96); + CpGM2UB(tokenPerExpertUB, (__gm__ int32_t *)tokenPerExpertDataInputGt.GetPhyAddr(), tokenPerExpertDataAlignLen); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + __ubuf__ T *sendDataOffsetUB = (__ubuf__ T *)get_imm(96 + tokenPerExpertDataAlignLen); + __ubuf__ T *sendDataUB = (__ubuf__ T *)get_imm(96 + tokenPerExpertDataAlignLen + sendDataOffsetAlignLen); + + int prefixSum = 0; + for (int i = 0; i < numExperts; ++i) { + int numTokensExpert = tokenPerExpertUB[i]; + sendDataUB[i * sendPerGroup] = numTokensExpert; + sendDataUB[i * sendPerGroup + 1] = prefixSum; + sendDataUB[i * sendPerGroup + 2] = numTokens; + sendDataOffsetUB[i] = prefixSum; + + prefixSum += numTokensExpert; + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + CpUB2GM((__gm__ T *)sendDataInputGt.GetPhyAddr(), sendDataUB, sendDataAlignLen); + CpUB2GM((__gm__ T *)sendDataOffsetOutputGt.GetPhyAddr(), sendDataOffsetUB, sendDataOffsetAlignLen); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + + __aicore__ inline void ProcessBetweenServer() + { + InputToWindowOut(); + // printflag("AfterInputToWindowOut\n") + ConstructBatchWriteInfo(); + // printflag("ConstructBatchWriteInfo\n") + SyncAll(); + SendRdma(); + // printflag("SendRdma\n") + WaitRdma(); + // printflag("WaitRdma\n") + SyncAll(); + // printflag("before WindowInToOutput\n") + WindowInToOutput(); + // printflag("WindowInToOutput\n") + } + + // 从input将数据拷贝到windowOutTensor,供RDMA进行发送 + __aicore__ inline void InputToWindowOut() + { + /* statusFlag 和 dataFlag 为int32_t,各自占用8B中的前4Bytes + --------------------------------------------------------------------------------------------------------------- + |8B pads|flag 8B|source 4B|target 4B|datalen 4B|4B pads| Data (datalen * sizeof(T)) | flag 8B | 24B pads | + --------------------------------------------------------------------------------------------------------------- + */ + if (blockIdx > 1) { + return; + } + int32_t targetRankId = 0; + if (blockIdx == 0) { + // targetRankId = rank; + return; // 同server的不搬运 + } else { // blockIdx=1 + targetRankId = (1 - serverId) * localRankSize + localRank; // 2个server的计算方式,求对端同号卡rankid + } + int32_t targetServerId = targetRankId / localRankSize; + + int64_t datalen = this->len; + readGt = sendDataInputGt[0]; // 读取全部数据 + + // 计算各个位置的offset,in bytes + int64_t statusEntryOffset = queSize * targetRankId; + int64_t statusFlagOffset = statusEntryOffset + FLAG_OFFSET; + int64_t sourceRankIdOffset = statusEntryOffset + SOURCE_RANK_OFFSET; + int64_t destRankIdOffset = statusEntryOffset + DEST_RANK_OFFSET; + int64_t dataLenOffset = statusEntryOffset + DATALEN_OFFSET; + int64_t dataOffset = statusEntryOffset + STATUS_ENTRY_SIZE; + int64_t dataFlagOffset = dataOffset + datalen * sizeof(T); + CpGM2GMPingPong(datalen * sizeof(T), readGt, windowOutTensor_[dataOffset / sizeof(T)], COPYONLY); + // printflag("enter2 InputToWindowOut\n") + + windowOutstatusTensor_(statusFlagOffset / sizeof(int32_t)) = FLAG_VALUE; + windowOutstatusTensor_(sourceRankIdOffset / sizeof(int32_t)) = rank; + windowOutstatusTensor_(destRankIdOffset / sizeof(int32_t)) = targetRankId; + windowOutstatusTensor_(dataLenOffset / sizeof(int32_t)) = (int32_t)datalen; + DataCacheCleanAndInvalid( + windowOutstatusTensor_[(statusEntryOffset / 
sizeof(int32_t))]); + windowOutstatusTensor_(dataFlagOffset / sizeof(int32_t)) = FLAG_VALUE; + DataCacheCleanAndInvalid( + windowOutstatusTensor_[(dataFlagOffset / sizeof(int32_t))]); + // PRINTF("##### rank:%d, blockId:%d, statusFlagOffset: %ld, windowOutstatusTensor_: %d, winOutAddr: %p \n", + // rank, blockIdx, + // statusFlagOffset, windowOutstatusTensor_.GetValue(statusFlagOffset / sizeof(int32_t)), + // windowOutstatusTensor_[statusFlagOffset].GetPhyAddr()); + // PRINTF("#####2 rank:%d, blockId:%d, statusEntryOffset: %ld, statusFlagOffset: %ld, sourceRankIdOffset: %ld, + // destRankIdOffset: %ld, dataLenOffset:%ld, dataOffset: %ld, dataFlagOffset: %ld\n", + // rank, blockIdx, statusEntryOffset, statusFlagOffset, sourceRankIdOffset, sourceRankIdOffset, + // destRankIdOffset, dataLenOffset, dataOffset, dataFlagOffset); + // AscendC::DumpTensor(windowOutstatusTensor_[statusEntryOffset / sizeof(int32_t)], 495, 8); + // printflag("enter3 InputToWindowOut\n") + } + + // 创建RDMA使用的batch write信息 + __aicore__ inline void ConstructBatchWriteInfo() + { + if (targetRankNum == 0 || blockIdx > 0) { + return; + } + + LocalTensor batchWriteU32Tensor_ = batchWriteInfoBuf_.Get(); + LocalTensor batchWriteU64Tensor_ = batchWriteInfoBuf_.Get(); + uint32_t batchWriteDataType = static_cast(AscendC::HcclDataType::HCCL_DATA_TYPE_INT8); + SyncFunc(); + + int32_t targetRankId = (1 - serverId) * localRankSize + localRank; // 2个server的计算方式 + + int32_t targetServerId = targetRankId / localRankSize; + uint32_t sendToRankId = targetServerId * localRankSize + localRank; // 数据发送目标Server的同号卡rankId + + // 数据在目标GM中的位置,保证第一轮数据不相互覆盖 + uint32_t sendOffset = serverId * localRankSize + (targetRankId % localRankSize); + + int64_t datalen = this->len; + GM_ADDR localBuf = (__gm__ uint8_t *)(windowOutGM_ + IPC_DATA_OFFSET + targetRankId * queSize); + GM_ADDR remoteGM = (__gm__ uint8_t *)(shareAddrs[sendToRankId] + IPC_DATA_OFFSET + rank * queSize); + uint64_t batchWriteDataSize = datalen * sizeof(T) + 2 * STATUS_ENTRY_SIZE; // payload加前后共2个flag长度 + + batchWriteU64Tensor_(0 * U64_PER_ITEM + BW_MEB_OFFSET64_LOCAL_GM) = (uint64_t)localBuf; + batchWriteU64Tensor_(0 * U64_PER_ITEM + BW_MEB_OFFSET64_REMOTE_GM) = (uint64_t)remoteGM; + batchWriteU64Tensor_(0 * U64_PER_ITEM + BW_MEB_OFFSET64_DATA_SIZE) = batchWriteDataSize; + batchWriteU32Tensor_(0 * U32_PER_ITEM + BW_MEB_OFFSET32_DATA_TYPE) = batchWriteDataType; + batchWriteU32Tensor_(0 * U32_PER_ITEM + BW_MEB_OFFSET32_TARGET_RANK) = sendToRankId; + + SyncFunc(); + // AscendC::DumpTensor(batchWriteInfoTensor_, 544, rankSize * U32_PER_ITEM); + DataCopy(batchWriteInfoTensor_[0], batchWriteU32Tensor_, 1 * U32_PER_ITEM); + PipeBarrier(); + } + + __aicore__ inline void SendRdma() + { + if (blockIdx == 0) { + HcclHandle batchWrResult = hccl_.BatchWrite((GM_ADDR)batchWriteInfoTensor_.GetPhyAddr(), 1); + } + } + + __aicore__ inline void WaitRdma() + { + if (targetRankNum == 0 || blockIdx > 0) { + return; + } + + DataCopyExtParams copyFlagParams{1, static_cast(sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0, 0, 0}; + LocalTensor dataFlagLocal = tempBuf_.Get(); + SyncFunc(); + + int32_t targetRankId = (1 - serverId) * localRankSize + localRank; // 2个server的计算方式 + int32_t targetServerId = targetRankId / localRankSize; + + int64_t statusOffset = targetRankId * queSize + FLAG_OFFSET; + // PRINTF("===rank:%d, blockId:%d, tarRankId:%d, tarServerId:%d, statusOffset: %ld, value: %d, winIn:%p \n", + // rank, blockIdx, targetRankId, targetServerId, statusOffset, + // 
windowInstatusTensor_.GetValue(statusOffset / sizeof(int32_t)), windowInstatusTensor_[statusOffset / + // sizeof(int32_t)].GetPhyAddr()); + // AscendC::DumpTensor(windowInstatusTensor_[targetRankId * queSize / sizeof(int32_t)], 589, U32_STATUS_ENTRY); + + int64_t datalen = 0; + int32_t statusFlag = 0; + int32_t dataFlag = 0; + // int64_t systemCycleBefore = AscendC::GetSystemCycle(); // 调用Add指令前的cycle数 + while (statusFlag != FLAG_VALUE) { + DataCopy(statusTensor_[0], windowInstatusTensor_[targetRankId * queSize / sizeof(int32_t)], + U32_STATUS_ENTRY); + SyncFunc(); + statusFlag = statusTensor_(FLAG_OFFSET / sizeof(int32_t)); + datalen = statusTensor_(DATALEN_OFFSET / sizeof(int32_t)); + PipeBarrier(); + + // int64_t systemCycleAfter = AscendC::GetSystemCycle(); // 调用Add指令后的cycle数 + // if ((systemCycleAfter - systemCycleBefore) / 50 > 1000000) { + // PRINTF("[1statusFlag] rank:%d, blockId:%d, tarRankId:%d, tarServerId:%d, statusOffset: %ld, value: + // %d, winIn:%p \n", rank, blockIdx, targetRankId, targetServerId, statusOffset, + // windowInstatusTensor_.GetValue(statusOffset / sizeof(int32_t)), + // windowInstatusTensor_[statusOffset / sizeof(int32_t)].GetPhyAddr()); + // AscendC::DumpTensor(windowInstatusTensor_[targetRankId * queSize / sizeof(int32_t)], 608, + // U32_STATUS_ENTRY * 32); break; + // } + } + + // int64_t systemCycleBefore2 = AscendC::GetSystemCycle(); // 调用Add指令前的cycle数 + uint64_t dataFlagOffset = (targetRankId * queSize + datalen * sizeof(T) + STATUS_ENTRY_SIZE) / sizeof(int32_t); + while (dataFlag != FLAG_VALUE) { + DataCopyPad(dataFlagLocal, windowInstatusTensor_[dataFlagOffset], copyFlagParams, padParams); + SyncFunc(); + dataFlag = dataFlagLocal(0); + PipeBarrier(); + // PRINTF("===[dataFlag]rank:%d, blockId:%d, tarRankId:%d, tarServerId:%d, dataFlag:%d \n", + // rank, blockIdx, targetRankId, targetServerId, dataFlag); + + // int64_t systemCycleAfter2 = AscendC::GetSystemCycle(); // 调用Add指令后的cycle数 + // if ((systemCycleAfter2 - systemCycleBefore2) / 50 > 1000000) { + // PRINTF("[1dataFlag] rank:%d, blockId:%d, tarRankId:%d, tarServerId:%d, dataFlagOffset: %d, value: %d + // \n", rank, blockIdx, targetRankId, targetServerId, dataFlagOffset, + // windowInstatusTensor_.GetValue(dataFlagOffset)); + // AscendC::DumpTensor(windowInstatusTensor_[dataFlagOffset], 628, sizeof(int32_t)); + // break; + // } + } + windowInstatusTensor_(dataFlagOffset) = 0; + } + + // 从RDMA收到的windowInTensor将数据拷贝到output + __aicore__ inline void WindowInToOutput() + { + /* + ---------------------------------------------------------------------------- + | STATUS_ENTRY_SIZE | Data (datalen * sizeof(T)) | STATUS_ENTRY_SIZE | + ---------------------------------------------------------------------------- + */ + if (blockIdx > 0) { + return; + } + int32_t targetRankId = (1 - serverId) * localRankSize + localRank; // 2个server的计算方式 + int64_t recvCount = this->len; + uint64_t dataOffset = (targetRankId * queSize + STATUS_ENTRY_SIZE) / sizeof(T); + CpGM2GMPingPong(recvCount * sizeof(T), windowInTensor_[dataOffset], + recvDataOutputGt[targetRankId * this->len], COPYONLY); + // AscendC::DumpTensor(windowInTensor_[dataOffset], 646, recvCount); + } + + // 从recvData拆分数据并计算输出 + __aicore__ inline void SplitAndCalcData() + { + pipe.Reset(); + pipe.InitBuffer(tempBuf_, UB_ALIGN); // 存放临时的立即数 + pipe.InitBuffer(tempBuf2_, 5000 * UB_ALIGN); // MAX_BS <= 4096, 要能放下一个bs的数据 + pipe.InitBuffer(tempBuf3_, numExperts * UB_ALIGN); // 要能放numExpert个数据 + + if (blockIdx == 0) { + // printflag("before 
BuildTokenUniquePerServerData\n"); + BuildTokenUniquePerServerData(); + // printflag("after BuildTokenUniquePerServerData\n"); + } + if (blockIdx == 1) { + // printflag("before BuildTokenSeverIdxData\n"); + BuildTokenSeverIdxData(); + // printflag("after BuildTokenSeverIdxData\n"); + } + if (blockIdx == 2) { + // printflag("before BuildCountOuterData\n"); + BuildCountOuterData(); + // printflag("after BuildCountOuterData\n"); + } + if (blockIdx == 3) { + // printflag("before BuildEpRankTokenCntAndSrcDstData\n"); + BuildEpRankTokenCntAndSrcDstData(); + // printflag("after BuildEpRankTokenCntAndSrcDstData\n"); + } + if (blockIdx == 4) { + // printflag("before BuildExpandIdxData\n"); + BuildExpandIdxData(); + // printflag("after BuildExpandIdxData\n"); + } + if (blockIdx == 5) { + // printflag("before BuildOffsetInnerData\n"); + BuildOffsetInnerData(); + // printflag("after BuildOffsetInnerData\n"); + } + } + + __aicore__ inline void BuildTokenSeverIdxData() + { + // printflag("enter BuildTokenSeverIdxData\n"); + // 计算 tokenServerIdxOutputGT_ + LocalTensor tmpLt = tempBuf2_.Get(); + DataCopyExtParams copyParams{1, static_cast(MAX_BS * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0, 0, 0}; + + // offset + numTokensPerExpertLen + numTokensUniquePerServerLen + numTokensPerServerLen + tokenServerCntLen + int32_t curRankDataOffset = rank * len + numExperts + serverNum + MAX_BS * serverNum + MAX_BS; + // AscendC::DumpTensor(recvDataOutputGt[curRankDataOffset], 652, 16); + + AscendC::SetFlag(EVENT_ID0); // MTE2 waits for MTE3 + for (int i = 0; i < serverNum; ++i) { + int32_t recvOffset = curRankDataOffset + i * MAX_BS; // 每次从recvdata中拷贝 MAX_BS 个数 + // PRINTF("[BuildTokenSeverIdxData] rank:%d, blockIdx:%d, recvOffset:%d\n", rank, blockIdx, recvOffset); + + event_t eventId = EVENT_ID0; + AscendC::WaitFlag(eventId); + + DataCopyPad(tmpLt, recvDataOutputGt[recvOffset], copyParams, padParams); + + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + // AscendC::DumpTensor(tmpLt, 667, 16); + + int32_t tarOffset = i * MAX_BS; + DataCopyPad(tokenServerIdxOutputGT_[tarOffset], tmpLt, copyParams); + // AscendC::DumpTensor(tokenServerIdxOutputGT_[tarOffset], 671, 16); + + AscendC::SetFlag(eventId); + } + AscendC::WaitFlag(EVENT_ID0); // MTE2 waits for MTE3 + // AscendC::DumpTensor(tokenServerIdxOutputGT_, 677, 16); + } + + __aicore__ inline void BuildExpandIdxData() + { + // printflag("enter BuildExpandIdxData\n"); + LocalTensor tmpLt = tempBuf2_.Get(); + DataCopyExtParams copyParams{1, static_cast(MAX_BS * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0, 0, 0}; + + // 计算 expandIdxOutputGT_ , 对应于输入 tokenExpertIdx + // offset + numTokensPerExpertLen + numTokensUniquePerServerLen + numTokensPerServerLen + tokenServerCntLen + + // tokenServerIdxLen + int32_t curRankDataOffset = + rank * len + numExperts + serverNum + MAX_BS * serverNum + MAX_BS + MAX_BS * serverNum; + // AscendC::DumpTensor(recvDataOutputGt[curRankDataOffset], 725, 16); + AscendC::SetFlag(EVENT_ID0); // MTE2 waits for MTE3 + for (int i = 0; i < numExperts; ++i) { + int32_t recvOffset = curRankDataOffset + i * MAX_BS; // 每次从recvdata中拷贝 MAX_BS 个数 + // PRINTF("[BuildExpandIdxData] rank:%d, blockIdx:%d, recvOffset:%d\n", rank, blockIdx, recvOffset); + + event_t eventId = EVENT_ID0; + AscendC::WaitFlag(eventId); + + DataCopyPad(tmpLt, recvDataOutputGt[recvOffset], copyParams, padParams); + + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + // AscendC::DumpTensor(tmpLt, 740, 16); + + int32_t 
tarOffset = i * MAX_BS; + DataCopyPad(expandIdxOutputGT_[tarOffset], tmpLt, copyParams); + // AscendC::DumpTensor(expandIdxOutputGT_[tarOffset], 744, 16); + + AscendC::SetFlag(eventId); + } + AscendC::WaitFlag(EVENT_ID0); // MTE2 waits for MTE3 + // AscendC::DumpTensor(expandIdxOutputGT_, 750, 16); + } + + __aicore__ inline void GetEpRankSumCnt(int32_t srcRank, LocalTensor &epTokenCntLt) + { + DataCopyExtParams copyParams{1, static_cast(numExperts * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0, 0, 0}; + + SyncFunc(); + + int32_t epTokenCntOffset = srcRank * len; + // AscendC::DumpTensor(recvDataOutputGt[epTokenCntOffset], 652, 16); + DataCopyPad(epTokenCntLt, recvDataOutputGt[epTokenCntOffset], copyParams, padParams); + + SyncFunc(); + + // 假设epTokenCntGt为 [2,2,2,2] --> 起始前缀和 [0,2,4,6] + int32_t preCnt = 0; + int32_t curVal = 0; + uint32_t localServerExpNum = numExperts / rankSize * localRankSize; + for (int32_t i = 0; i < numExperts; ++i) { + if (i % localServerExpNum == 0) { + preCnt = 0; + } + curVal = epTokenCntLt(i); + pipe_barrier(PIPE_ALL); + epTokenCntLt(i) = preCnt; // 设置为前一个元素的前缀和 + pipe_barrier(PIPE_ALL); + preCnt += curVal; + } + // AscendC::DumpTensor(epTokenCntLt, 748, 16); + } + + __aicore__ inline void BuildOffsetInnerData() + { + LocalTensor tmpLt = tempBuf2_.Get(); + LocalTensor epTokenCntLt = tempBuf3_.Get(); + DataCopyExtParams copyParams{1, static_cast(numExperts * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0, 0, 0}; + + // 计算 offsetInnerOutputGT_ , 为全局的 expandIdx + // shape[num_rank * max_bs, expertNum] value: inner_offset + AscendC::SetFlag(EVENT_ID0); // MTE2 waits for MTE3 + for (int srcRank = 0; srcRank < rankSize; ++srcRank) { + // 1.获取本卡发给每个expert的token个数起始前缀和 + GetEpRankSumCnt(srcRank, epTokenCntLt); + SyncFunc(); + + int32_t dataOffset = + srcRank * len + numExperts + serverNum + MAX_BS * serverNum + MAX_BS + MAX_BS * serverNum; + for (int tokId = 0; tokId < MAX_BS; ++tokId) { + int32_t recvOffset = dataOffset + tokId * numExperts; // 每次从recvdata中拷贝 numExperts 个数 + // PRINTF("[BuildOffsetInnerData] rank:%d, blockIdx:%d, recvOffset:%d\n", rank, blockIdx, recvOffset); + + event_t eventId = EVENT_ID0; + AscendC::WaitFlag(eventId); + + // 2.每个token发到expert的顺序, 当前token的expand_idx + DataCopyPad(tmpLt, recvDataOutputGt[recvOffset], copyParams, padParams); + SyncFunc(); + + // 3.求token在每个专家上的偏移 + // 遍历每个token的expand_idx, 如果不为-1,则加上 epTokenCntLt 中对应专家的值(对应列), 否则保持-1; + for (int32_t expId = 0; expId < numExperts; ++expId) { + int val = tmpLt(expId); + if (val == -1) { + continue; + } + val += epTokenCntLt(expId); + pipe_barrier(PIPE_ALL); + tmpLt(expId) = val; + pipe_barrier(PIPE_ALL); + } + + SyncFunc(); + // AscendC::DumpTensor(tmpLt, 793, 16); + + int32_t tarOffset = (srcRank * MAX_BS * numExperts) + tokId * numExperts; + DataCopyPad(offsetInnerOutputGT_[tarOffset], tmpLt, copyParams); + // AscendC::DumpTensor(offsetInnerOutputGT_[tarOffset], 797, 16); + + AscendC::SetFlag(eventId); + } + } + AscendC::WaitFlag(EVENT_ID0); // MTE2 waits for MTE3 + } + + __aicore__ inline void BuildCountOuterData() + { + // printflag("enter BuildCountOuterData\n"); + // 计算 countOuterOutputGT_ + LocalTensor tmpLt = tempBuf2_.Get(); + DataCopyExtParams copyParams{1, static_cast(MAX_BS * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0, 0, 0}; + + // offset + numTokensPerExpertLen + numTokensUniquePerServerLen + numTokensPerServerLen + int32_t curRankDataOffset = rank * len + numExperts + serverNum + MAX_BS * 
serverNum; + // PRINTF("[BuildCountOuterData] rank:%d, blockIdx:%d, curRankDataOffset:%d\n", rank, blockIdx, + // curRankDataOffset); + + DataCopyPad(tmpLt, recvDataOutputGt[curRankDataOffset], copyParams, padParams); + + SyncFunc(); + // AscendC::DumpTensor(tmpLt, 744, 16); + + DataCopyPad(countOuterOutputGT_, tmpLt, copyParams); + SyncFunc(); + // AscendC::DumpTensor(countOuterOutputGT_, 749, 16); + } + + __aicore__ inline void BuildTokenUniquePerServerData() + { + // printflag("enter BuildTokenUniquePerServerData\n"); + // 计算 tokensUniquePerServerOutputGT_ + LocalTensor tmpLt = tempBuf2_.Get(); + DataCopyExtParams copyParams{1, static_cast(serverNum * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams{false, 0, 0, 0}; + + int32_t curRankDataOffset = rank * len + numExperts; // offset + numTokensPerExpertLen + // PRINTF("[BuildTokenUniquePerServerData] rank:%d, blockIdx:%d, curRankDataOffset:%d\n", rank, blockIdx, + // curRankDataOffset); + DataCopyPad(tmpLt, recvDataOutputGt[curRankDataOffset], copyParams, padParams); + + SyncFunc(); + + DataCopyPad(tokensUniquePerServerOutputGT_, tmpLt, copyParams); + SyncFunc(); + } + + __aicore__ inline void BuildEpRankTokenCntAndSrcDstData() + { + // printflag("enter BuildEpRankTokenCntAndSrcDstData\n"); + LocalTensor tmpLt = tempBuf2_.Get(); + + GlobalTensor gEpRankTokenCntGT_; + gEpRankTokenCntGT_.SetGlobalBuffer( + (__gm__ int32_t *)(tmpDataInput), + gNumTokensPerExpertAlignLen / sizeof(int32_t)); // tmpDataInput地址用作临时存数 + // 计算 epRankTokenCntOutputGT_ + DataCopyExtParams copyParams1{1, static_cast(numExperts * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams1{false, 0, 0, 0}; + int32_t curRankDataOffset = rank * len; + // AscendC::DumpTensor(recvDataOutputGt[curRankDataOffset], 652, 16); + + AscendC::SetFlag(EVENT_ID0); // MTE2 waits for MTE3 + for (int i = 0; i < rankSize; ++i) { + int32_t recvOffset = i * len; // 每次从recvdata中拷贝 MAX_BS 个数 + // PRINTF("[BuildEpRankTokenCntAndSrcDstData1] rank:%d, blockIdx:%d, recvOffset:%d\n", rank, blockIdx, + // recvOffset); + + event_t eventId = EVENT_ID0; + AscendC::WaitFlag(eventId); + + DataCopyPad(tmpLt, recvDataOutputGt[recvOffset], copyParams1, padParams1); + + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + // AscendC::DumpTensor(tmpLt, 667, 16); + + int32_t tarOffset = i * numExperts; + DataCopyPad(gEpRankTokenCntGT_[tarOffset], tmpLt, copyParams1); + // AscendC::DumpTensor(tokenServerIdxOutputGT_[tarOffset], 671, 16); + + AscendC::SetFlag(eventId); + } + AscendC::WaitFlag(EVENT_ID0); // MTE2 waits for MTE3 + + SyncFunc(); + DataCopyExtParams copyParams2{1, static_cast(1 * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams2{false, 0, 0, 0}; + AscendC::SetFlag(EVENT_ID0); // MTE2 waits for MTE3 + // shape[rankSize, numExperts] --> shape[numExperts, rankSize] value: cnt + for (int srcRank = 0; srcRank < rankSize; ++srcRank) { + for (int curExp = 0; curExp < numExperts; ++curExp) { + int32_t inOffset = srcRank * numExperts + curExp; // 只拷贝一个值 + // PRINTF("[BuildEpRankTokenCntAndSrcDstData2] rank:%d, blockIdx:%d, srcRank:%d, curExp:%d, inOffset:%d + // \n", + // rank, blockIdx, srcRank, curExp, inOffset); + + event_t eventId = EVENT_ID0; + AscendC::WaitFlag(eventId); + + DataCopyPad(tmpLt, gEpRankTokenCntGT_[inOffset], copyParams2, padParams2); + + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + // AscendC::DumpTensor(tmpLt, 667, 16); + + int32_t outOffset = curExp * rankSize + srcRank; + DataCopyPad(epRankTokenCntOutputGT_[outOffset], tmpLt, 
copyParams2); + // AscendC::DumpTensor(epRankTokenCntOutputGT_[tarOffset], 671, 16); + + AscendC::SetFlag(eventId); + } + } + AscendC::WaitFlag(EVENT_ID0); // MTE2 waits for MTE3 + SyncFunc(); + + // 计算 localEpTokenCntOutputGT_ , shape[localExperts, rankSize] value: sumCnt 前缀和 + int32_t localExpertNum = numExperts / rankSize; + int32_t preCnt = 0; + AscendC::SetFlag(EVENT_ID0); // MTE2 waits for MTE3 + for (int i = 0; i < localExpertNum; ++i) { + for (int j = 0; j < rankSize; ++j) { + int32_t inOffset = (rank * localExpertNum + i) * rankSize + j; // 拷贝当前专家的1个值,对应不同rank来的 + + event_t eventId = EVENT_ID0; + AscendC::WaitFlag(eventId); + + DataCopyPad(tmpLt, epRankTokenCntOutputGT_[inOffset], copyParams2, padParams2); + + SyncFunc(); + int32_t cnt = tmpLt(0); + pipe_barrier(PIPE_ALL); + preCnt += cnt; + tmpLt(0) = preCnt; + pipe_barrier(PIPE_ALL); + // PRINTF("[BuildEpRankTokenCntAndSrcDstData3] rank:%d, blockIdx:%d, i:%d, j:%d, inOffset:%d, cnt:%d, + // preCnt:%d \n", + // rank, blockIdx, i, j, inOffset, cnt, preCnt); + + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + // AscendC::DumpTensor(tmpLt, 667, 16); + + int32_t outOffset = i * rankSize + j; + DataCopyPad(localEpTokenCntOutputGT_[outOffset], tmpLt, copyParams2); + // AscendC::DumpTensor(epRankTokenCntOutputGT_[tarOffset], 671, 16); + + AscendC::SetFlag(eventId); + } + } + AscendC::WaitFlag(EVENT_ID0); // MTE2 waits for MTE3 + SyncFunc(); + + GlobalTensor gExpertMaxBsSrcGT_; + gExpertMaxBsSrcGT_.SetGlobalBuffer( + (__gm__ int32_t *)(tmpDataInput + (gNumTokensPerExpertAlignLen) / sizeof(int32_t)), + gExpertMaxBsSrcOffsetAlignLen / sizeof(int32_t)); // sendDataInput地址用作临时存数 + // 计算 gExpertMaxBsSrcGT_ + DataCopyExtParams copyParams3{1, static_cast(MAX_BS * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams3{false, 0, 0, 0}; + + AscendC::SetFlag(EVENT_ID0); // MTE2 waits for MTE3 + for (int i = 0; i < rankSize; ++i) { + int32_t dataOffset = i * len + numExperts + serverNum + MAX_BS * serverNum + MAX_BS + MAX_BS * serverNum + + MAX_BS * numExperts; + for (int j = 0; j < numExperts; ++j) { + int32_t recvOffset = dataOffset + j * MAX_BS; // 每次从recvdata中拷贝 MAX_BS 个数 + // PRINTF("[BuildEpRankTokenCntAndSrcDstData1] rank:%d, blockIdx:%d, recvOffset:%d\n", rank, blockIdx, + // recvOffset); + + event_t eventId = EVENT_ID0; + AscendC::WaitFlag(eventId); + + DataCopyPad(tmpLt, recvDataOutputGt[recvOffset], copyParams3, padParams3); + + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + // AscendC::DumpTensor(tmpLt, 667, 16); + + int32_t tarOffset = (i * numExperts * MAX_BS) + j * MAX_BS; + DataCopyPad(gExpertMaxBsSrcGT_[tarOffset], tmpLt, copyParams3); + // if (i == 0) { + // AscendC::DumpTensor(gExpertMaxBsSrcGT_[tarOffset], 1025, 16); + // } + + AscendC::SetFlag(eventId); + } + } + AscendC::WaitFlag(EVENT_ID0); // MTE2 waits for MTE3 + SyncFunc(); + + /** 计算 srcOffsetRankTokenIdxOutputGT_ / dstOffsetRankTokenIdxOutputGT_ + * shape[num_expert, num_rank, max_bs] value: src_offset/dst_offset <--- shape[num_rank, num_expert, max_bs] + */ + AscendC::SetFlag(EVENT_ID0); // MTE2 waits for MTE3 + DataCopyExtParams copyParams4{1, static_cast(1 * sizeof(int32_t)), 0, 0, 0}; + DataCopyPadExtParams padParams4{false, 0, 0, 0}; + int32_t dstOffsetStart = 0; + LocalTensor dstOffsetLt = tempBuf_.Get(); // 存立即数的buf + + for (int expId = 0; expId < numExperts; ++expId) { + if (expId % localExpertNum == 0) { + dstOffsetStart = 0; // 每次所属rank递增后,计算desOffset的起始位置需要重置为0 + } + for (int srcRank = 0; srcRank < rankSize; ++srcRank) { + 
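// For this (expId, srcRank) pair: epRankTokenCntOutputGT_[expId * rankSize + srcRank]
// gives how many tokens srcRank actually sends to expert expId (validTokenCnt). The
// inner tokId loop then re-lays the per-token source offsets out from
// [rankSize, numExperts, MAX_BS] (gExpertMaxBsSrcGT_) into
// [numExperts, rankSize, MAX_BS] (srcOffsetRankTokenIdxOutputGT_), while the
// destination offset written to dstOffsetRankTokenIdxOutputGT_ is a running counter
// (dstOffsetStart) that advances only for the first validTokenCnt slots and is -1 for
// padding slots; the counter was reset above whenever expId crosses a localExpertNum
// boundary, i.e. when the experts start belonging to the next rank.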
DataCopyPad(tmpLt, epRankTokenCntOutputGT_[expId * rankSize + srcRank], copyParams4, + padParams4); // 只拷贝一个数 + SyncFunc(); + int32_t validTokenCnt = tmpLt(0); + pipe_barrier(PIPE_ALL); + + for (int tokId = 0; tokId < MAX_BS; ++tokId) { + event_t eventId = EVENT_ID0; + AscendC::WaitFlag(eventId); + + SyncFunc(); // 复用tmpLt,加一个同步 + int32_t inIdx = srcRank * numExperts * MAX_BS + expId * MAX_BS + tokId; + DataCopyPad(tmpLt, gExpertMaxBsSrcGT_[inIdx], copyParams4, padParams4); // 只拷贝一个数 + SyncFunc(); + int32_t srcOffset = tmpLt(0); + pipe_barrier(PIPE_ALL); + + SyncFunc(); + int32_t outIdx = expId * rankSize * MAX_BS + srcRank * MAX_BS + tokId; + DataCopyPad(srcOffsetRankTokenIdxOutputGT_[outIdx], tmpLt, copyParams4); + // AscendC::DumpTensor(tmpLt, 667, 16); + + if (tokId < validTokenCnt) { + dstOffsetLt(0) = dstOffsetStart; + pipe_barrier(PIPE_ALL); + dstOffsetStart++; // 有效token,写入当前rank的output目的偏移位置需要递增 + } else { + dstOffsetLt(0) = -1; + pipe_barrier(PIPE_ALL); + } + + // if (srcOffset == 4096) { + // PRINTF("[BuildEpRankTokenCntAndSrcDstData4] rank:%d, blockIdx:%d, expId:%d, srcRank:%d, + // tokId:%d, \ + // inIdx:%d, srcOffset:%d, validTokenCnt:%d, outIdx:%d, dstOffsetStart:%d\n", + // rank, blockIdx, expId, srcRank, tokId, inIdx, srcOffset, validTokenCnt, outIdx, + // dstOffsetStart); + // } + + SyncFunc(); + + DataCopyPad(dstOffsetRankTokenIdxOutputGT_[outIdx], dstOffsetLt, copyParams4); + // AscendC::DumpTensor(dstOffsetRankTokenIdxOutputGT_[outIdx], 671, 16); + + AscendC::SetFlag(eventId); + } + } + } + AscendC::WaitFlag(EVENT_ID0); // MTE2 waits for MTE3 + SyncFunc(); + } + + GlobalTensor sendDataInputGt; + GlobalTensor tokenPerExpertDataInputGt; + GlobalTensor tmpDataInputGt; + GlobalTensor sendDataOffsetOutputGt; + GlobalTensor recvDataOutputGt; + GlobalTensor readGt; + GlobalTensor writeGt; + GlobalTensor remoteGt; + + __gm__ T *sendDataInput; + __gm__ int *tokenPerExpertDataInput; + __gm__ int *tmpDataInput; + __gm__ T *sendDataOffsetOutput; + __gm__ T *recvDataOutput; + + int64_t queLen; + int64_t queSize; + int64_t queElemLen; // 共享内存队列里每个元素大小(以sizeof(T)计) + + int64_t coreNumBetween; // 分层通信第一阶段,Server间通信使用的核数 + int64_t coreNumWithin; // 分层通信第二阶段,Server内通信使用的核数 + int32_t rankNumPerCore; // 每个核负责的rank数 + + // RDMA相关变量 + int32_t serverNum; // Server数量 + int32_t serverId; // 本卡所属的server ID + int32_t targetRank[MULTI_RANK_SIZE]; // 当前核心跨Server发送数据的目标rank Id,即数据最终的目标rank + int32_t targetRankNum; // 当前核心跨Server发送数据的目标rank Id的数量,小于等于MULTI_RANK_SIZE + int64_t perRankDataNum; + + int rank; + int rankSize; + int localRank = 0; + int localRankSize = 0; // 在910A5中,表示一块板子上使用的卡数,在910B上表示单机内卡数。 + int xRankSize = 0; + int yRankSize = 0; + int xRankIdx = 0; + int yRankIdx = 0; + uint32_t extraFlag; + int root; + int sendPerGroup = 3; + int topkNum; + int64_t numExperts; + int64_t numTokens; + int64_t len; + int64_t magic; + int64_t blockIdx; // 当前aicore序号 + int64_t blockNum; // 当前rank的总aicore数 + int64_t timeout; + GM_ADDR scale; + GM_ADDR shareAddrs[CAM_MAX_RANK_SIZE]; // 共享内存地址列表 + __gm__ HcclOpResParam *winContext_[COMM_NUM]{nullptr, nullptr}; + TPipe pipe; // pipe工具类 + TBuf tBuf; + SyncCollectives sync; + + Hccl hccl_; + GM_ADDR windowInGM_; + GM_ADDR windowOutGM_; + GlobalTensor magicTensor_; // 用于存放magic,位于windowInstatusTensor_之前 + GlobalTensor batchWriteInfoTensor_; + GlobalTensor windowInstatusTensor_; // 用于rank间状态同步 + GlobalTensor windowInTensor_; + GlobalTensor windowOutstatusTensor_; // 用于rank间状态同步 + GlobalTensor windowOutTensor_; + TBuf<> batchWriteInfoBuf_; // 临时存放 batch write 
info + TBuf<> tempBuf_; + TBuf<> statusBuf_; + LocalTensor statusTensor_; // 临时存放statusFlag + TBuf<> tokenPerExpertDataBuf; + TBuf<> sendDataOffsetBuf; + TBuf<> sendDataBuf; + TBuf<> tempBuf2_; + TBuf<> tempBuf3_; + + uint32_t sendDataAlignLen{0}; + uint32_t tokenPerExpertDataAlignLen{0}; + uint32_t sendDataOffsetAlignLen{0}; + + uint32_t numTokensPerExpertAlignLen{0}; // 每个expert从本卡接收的token个数,对应一个rank的数据 + uint32_t gNumTokensPerExpertAlignLen{0}; // 全局,包含所有rank的 + uint32_t numTokensUniquePerServerAlignLen{0}; // 每个server从本卡接收的token个数(去重),对应一个rank的 + uint32_t gNumTokensUniquePerServerAlignLen{0}; // 全局,包含所有rank的 + uint32_t numTokensPerServerAlignLen{0}; // 本卡每个token发到每个server的个数(不去重), 对应一个rank的 + uint32_t gNumTokensPerServerAlignLen{0}; // 全局,包含所有rank的 + uint32_t tokenServerCntAlignLen{0}; // 本卡每个token发给多少个server, 对应一个rank的 + uint32_t gTokenServerCntAlignLen{0}; // 全局,包含所有rank的 + uint32_t tokenServerIdxAlignLen{0}; // 本卡每个token发送给各个server的顺序, 对应一个rank的 + uint32_t gTokenServerIdxAlignLen{0}; // 全局,包含所有rank的 + uint32_t tokenExpertIdxAlignLen{0}; // 每个token发到expert的顺序, 对应一个rank的 + uint32_t gTokenExpertIdxAlignLen{0}; // 全局,包含所有rank的 + uint32_t expertMaxBsSrcOffsetAlignLen{0}; // 每个expert从本卡接收的token的server内offset, 对应一个rank的 + uint32_t gExpertMaxBsSrcOffsetAlignLen{0}; // 全局,包含所有rank的 + uint32_t expertMaxBsOriOffsetAlignLen{0}; // 每个expert从本卡接收的token在原卡上的origin_offset, 对应一个rank的 + uint32_t gExpertMaxBsOriOffsetAlignLen{0}; // 全局,包含所有rank的 + + // GM_ADDR dataSpaceGT_; + __gm__ int32_t *epRankTokenCntOutputGM_; + GlobalTensor + tokenServerIdxOutputGT_; // token发送给对应server的token序号,-1表示没有,0-N表示序号 [bs, serverNum] + GlobalTensor + tokensUniquePerServerOutputGT_; // 当前rank发送给对应server的token个数 [serverNum] -> value:count数量 + GlobalTensor + epRankTokenCntOutputGT_; // 每个专家、从rank接收的token数量 [expert_num, rank_num] -> value:token_cnt + GlobalTensor localEpTokenCntOutputGT_; // 本卡每个专家、从rank接收的token数量 [local_expert_num, rank_num] + GlobalTensor srcOffsetRankTokenIdxOutputGT_; // 每个专家、从rank接收的token源端偏移 [expert_num, rank_num, + // token_idx] -> value:src_offset + GlobalTensor dstOffsetRankTokenIdxOutputGT_; // 每个专家、从rank接收的token目的端偏移 [expert_num, + // rank_num, token_idx] -> value:dst_offset + GlobalTensor countInnerOutputGT_; // token给各个server发送个数 弃用 + GlobalTensor offsetInnerOutputGT_; // 存放全局的expandIdx, [globalBs, expertNum] + GlobalTensor countOuterOutputGT_; // 每个token发送到的server数量 [bs] -> value:server数量 + GlobalTensor offsetOuterOutputGT_; // 每个token在server上的位次 同tokenServerIdxOutputGT_ + GlobalTensor + expandIdxOutputGT_; // 给同一专家的token个数 [bs * numExperts], topk_idx的同专家前缀和扩维到所有专家维度 +}; + +template +template +FORCE_INLINE_AICORE void NotifyDispatchA2::SetAtomic(int op) +{ + PipeBarrier(); + if (op != -1) { +#ifdef __DAV_C220_VEC__ + SetAtomicOpType(op); +#endif + } + PipeBarrier(); +} + +template +template +FORCE_INLINE_AICORE void NotifyDispatchA2::SetWaitEvent(event_t eventId) +{ + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); +} + +template +FORCE_INLINE_AICORE void NotifyDispatchA2::UnsetAtomic(int op) +{ + if (op != -1) { + AscendC::SetAtomicNone(); + } + PipeBarrier(); +} + +template +template +FORCE_INLINE_AICORE void NotifyDispatchA2::CpGM2GMPingPong(int64_t dataSizeRemain, + const GlobalTensor &sendDataInputGt, + const GlobalTensor &recvDataOutputGT, int op) +{ + // General case (U = K), input/output are the same, share one UB + // Only when conversion is needed (U->K), UB will be divided into two parts according to the ratio of + // sizeof(U):sizeof(K) and aligned to 32 bytes + constexpr 
int32_t ubBlockSize = UB_SINGLE_PING_PONG_ADD_SIZE_MAX; + constexpr int32_t ubAlignNum = ubBlockSize / (sizeof(K) + sizeof(U)) / UB_ALIGN_SIZE * UB_ALIGN_SIZE; + constexpr int32_t inputUbBlockSize = std::is_same_v ? ubBlockSize : ubAlignNum * sizeof(U); + constexpr int32_t outputUbBlockSize = std::is_same_v ? ubBlockSize : ubAlignNum * sizeof(K); + + __gm__ U *input = const_cast<__gm__ U *>(sendDataInputGt.GetPhyAddr()); + __gm__ K *output = const_cast<__gm__ K *>(recvDataOutputGT.GetPhyAddr()); + __ubuf__ U *inputUB[2] = {(__ubuf__ U *)(UB_HEAD_OFFSET), (__ubuf__ U *)(UB_MID_OFFSET)}; + __ubuf__ K *outputUB[2] = {(__ubuf__ K *)inputUB[0], (__ubuf__ K *)inputUB[1]}; + if constexpr (!std::is_same_v) { + outputUB[0] = (__ubuf__ K *)(inputUB[0] + inputUbBlockSize / sizeof(U)); + outputUB[1] = (__ubuf__ K *)(inputUB[1] + inputUbBlockSize / sizeof(U)); + } + int inputOffsetNum = 0; + int outputOffsetNum = 0; + if (dataSizeRemain <= 0) { + return; + } + + SetAtomic(op); + + AscendC::SetFlag(EVENT_ID0); // MTE2 waits for MTE3 + AscendC::SetFlag(EVENT_ID1); // MTE2 waits for MTE3 + for (int64_t i = 0; dataSizeRemain > 0; i++) { + // size and dataSizeRemain both refer to the output size + uint32_t size = dataSizeRemain > outputUbBlockSize ? outputUbBlockSize : dataSizeRemain; + event_t eventId = (i & 1) ? EVENT_ID0 : EVENT_ID1; + AscendC::WaitFlag(eventId); + CpGM2UB((i & 1) ? inputUB[0] : inputUB[1], input + inputOffsetNum, size / sizeof(K) * sizeof(U)); + if constexpr (!std::is_same_v) { + SetWaitEvent(eventId); + CastImpl((i & 1) ? outputUB[0] : outputUB[1], (i & 1) ? inputUB[0] : inputUB[1], RoundMode::CAST_NONE, + size / sizeof(K)); + SetWaitEvent(eventId); + } + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + CpUB2GM(output + outputOffsetNum, (i & 1) ? 
outputUB[0] : outputUB[1], size); + AscendC::SetFlag(eventId); + + dataSizeRemain -= size; + inputOffsetNum += (size / sizeof(K)); + outputOffsetNum += (size / sizeof(K)); + } + AscendC::WaitFlag(EVENT_ID0); // MTE2 waits for MTE3 + AscendC::WaitFlag(EVENT_ID1); // MTE2 waits for MTE3 + + AscendC::SetFlag(EVENT_ID3); // Scalar waits for MTE3 + AscendC::WaitFlag(EVENT_ID3); + + UnsetAtomic(op); + return; +} + +#endif /* ALL2ALL_V_LAYERED_RDMA_H */ diff --git a/csrc/deepep/ops2/op_kernel/notify_dispatch_tiling_a2.h b/csrc/deepep/ops2/op_kernel/notify_dispatch_tiling_a2.h new file mode 100644 index 00000000..5c614a58 --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/notify_dispatch_tiling_a2.h @@ -0,0 +1,25 @@ +#ifndef NOTIFY_DISPATCH_TILING_A2_H +#define NOTIFY_DISPATCH_TILING_A2_H + +#include "kernel_tiling/kernel_tiling.h" + +struct NotifyDispatchInfoA2 { + uint32_t rankSize; + uint32_t rankId; + uint32_t localRankSize; + uint32_t localRankId; + uint32_t sendCount; + uint32_t numTokens; + uint32_t topkNum; + uint32_t numExperts; + uint32_t aivNum; + uint64_t totalUbSize; +}; + +struct NotifyDispatchA2TilingData { + Mc2InitTiling mc2InitTiling; + Mc2CcTiling mc2CcTiling1; + NotifyDispatchInfoA2 notifyDispatchInfoA2; +}; + +#endif diff --git a/csrc/deepep/ops2/op_kernel/sync_collectives.h b/csrc/deepep/ops2/op_kernel/sync_collectives.h new file mode 100644 index 00000000..11b75cbb --- /dev/null +++ b/csrc/deepep/ops2/op_kernel/sync_collectives.h @@ -0,0 +1,433 @@ +#ifndef SYNC_COLLECTIVES_H +#define SYNC_COLLECTIVES_H + +#include "comm_args.h" + +using namespace AscendC; +using namespace Moe; + +// Synchronization flag occupies length +constexpr int64_t FLAG_UNIT_INT_NUM = 4; +// Memory size occupied by each synchronization unit (Bytes) +constexpr int64_t SYNC_UNIT_SIZE = FLAG_UNIT_INT_NUM * sizeof(int64_t); +// High-order offset when using magic as a comparison value +constexpr int64_t MAGIC_OFFSET = 32; +constexpr int64_t MAGIC_MASK = ~((1LL << MAGIC_OFFSET) - 1); + +class SyncCollectives +{ +public: + __aicore__ inline SyncCollectives() {} + + __aicore__ inline void Init(int rank, int rankSize, GM_ADDR *shareAddrs, TBuf &tBuf) + { + this->rank = rank; + this->rankSize = rankSize; + this->shareAddrs = shareAddrs; + this->blockIdx = GetBlockIdx(); + this->blockNum = GetBlockNum(); + // Length of a single indicator segment + segmentCount = GetBlockNum() * FLAG_UNIT_INT_NUM; + // Initialize the intra-card/inter-card synchronization address corresponding to the current core. + localSyncAddr = (__gm__ int64_t *)(shareAddrs[rank]); + basicSyncAddr = (__gm__ int64_t *)(shareAddrs[rank]) + GetBlockIdx() * FLAG_UNIT_INT_NUM; + blockOuterSyncAddr = (__gm__ int64_t *)(shareAddrs[rank]) + segmentCount + GetBlockIdx() * FLAG_UNIT_INT_NUM; + this->tBuf = tBuf; + } + + __aicore__ inline void SetSyncFlag(int32_t magic, int32_t value, int32_t eventID) + { + int64_t v = MergeMagicWithValue(magic, value); + SetFlag(localSyncAddr + eventID * FLAG_UNIT_INT_NUM, v); + } + + /** + * @brief Set the flag for the specified eventID of the designated card, with the value being a combination of magic + * and value. + * @param magic The operator batch, which will be combined into the high 32 bits of the flag value to be set. + * @param value The specific value to be set, which will be the low 32 bits of the flag value to be set. + * @param eventID Physically, it is an offset from the shared memory base address (requires scaling, not an absolute + * value). 
+ * @param rank This rank is the rankId corresponding to the peerMems array in the CommArgs structure, not a global + * or local id. (Local is not applicable in the 91093 scenario, and global is not applicable in the 910B + * multi-machine scenario.) + */ + __aicore__ inline void SetSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank) + { + int64_t v = MergeMagicWithValue(magic, value); + SetFlag((__gm__ int64_t *)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, v); + } + + __aicore__ inline int32_t CalEventIdByMulBlockNum(int32_t blockMultiplier, int32_t targetCoreId) + { + return (blockMultiplier * blockNum) + targetCoreId; + } + + /** + * @brief Wait for the flag of the specified eventID on the specified card to become a value + * composed of the combination of magic and value. + * @param magic The operator batch, which will be combined into the high 32 bits of the flag + * value to wait for. + * @param value The specific value to wait for, which will be the low 32 bits of the flag + * value to wait for. + * @param eventID Physically, it is an offset from the shared memory base address (requires + * scaling, not an absolute value). + * @param rank This rank is the rankId corresponding to the peerMems array in the CommArgs + * structure, not a global or local id. (Local is not applicable in the 91093 + * scenario, and global is not applicable in the 910B multi-machine scenario.) + */ + __aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank) + { + int64_t v = MergeMagicWithValue(magic, value); + WaitOneRankPartFlag((__gm__ int64_t *)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, 1, v); + } + + __aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID) + { + int64_t v = MergeMagicWithValue(magic, value); + WaitOneRankPartFlag((__gm__ int64_t *)(shareAddrs[this->rank]) + eventID * FLAG_UNIT_INT_NUM, 1, v); + } + + /** + * @brief Wait for the flags starting from the specified eventID on the specified card to become + * a value composed of the combination of magic and value.
+ * Note: [eventID, eventID + flagNum) + */ + __aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank, int64_t flagNum) + { + int64_t v = MergeMagicWithValue(magic, value); + WaitOneRankPartFlag((__gm__ int64_t *)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, flagNum, v); + } + + // Set inner-card synchronization flag (memory A) + __aicore__ inline void SetInnerFlag(int32_t magic, int32_t eventID) + { + int64_t value = MergeMagicWithValue(magic, eventID); + SetFlag(basicSyncAddr, value); + } + + __aicore__ inline void SetInnerFlag(int32_t magic, int32_t eventID, int64_t setRank, int64_t setBlock) + { + int64_t value = MergeMagicWithValue(magic, eventID); + SetFlag((__gm__ int64_t *)(shareAddrs[setRank]) + setBlock * FLAG_UNIT_INT_NUM, value); + } + + // Wait for a single inner-card synchronization flag (memory A) + __aicore__ inline void WaitInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank, int64_t waitBlock) + { + int64_t value = MergeMagicWithValue(magic, eventID); + WaitOneRankPartFlag((__gm__ int64_t *)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM, 1, value); + } + + // Wait for all inner-card synchronization flags within the entire rank (memory A) + __aicore__ inline void WaitRankInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank) + { + int64_t value = MergeMagicWithValue(magic, eventID); + WaitOneRankAllFlag((__gm__ int64_t *)(shareAddrs[waitRank]), value); + } + + // Check all inner-card synchronization flags within the entire rank (memory A) + __aicore__ inline bool CheckRankInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank) + { + int64_t value = MergeMagicWithValue(magic, eventID); + return CheckOneRankAllFlag((__gm__ int64_t *)(shareAddrs[waitRank]), value); + } + + // Set inter-card synchronization flag (memory B) + __aicore__ inline void SetOuterFlag(int32_t magic, int32_t eventID) + { + int64_t value = MergeMagicWithValue(magic, eventID); + SetFlag(blockOuterSyncAddr, value); + } + + __aicore__ inline void SetOuterFlag(int32_t magic, int32_t eventID, int64_t setRank, int64_t setBlock) + { + __gm__ int64_t *flagAddr = GetOuterFlagAddr(setRank, setBlock); + int64_t value = MergeMagicWithValue(magic, eventID); + SetFlag(flagAddr, value); + } + + // Wait for a single inter-card synchronization flag (memory B) + __aicore__ inline void WaitOuterFlag(int32_t magic, int32_t eventID, int64_t waitRank, int64_t waitBlock) + { + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr = GetOuterFlagAddr(waitRank, waitBlock); + WaitOneRankPartFlag(flagAddr, 1, value); + } + + // Wait for all inter-card synchronization flags within the entire rank (memory B) + __aicore__ inline void WaitOneRankOuterFlag(int32_t magic, int32_t eventID, int64_t rank) + { + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr; + flagAddr = GetOuterFlagAddr(rank, 0); + WaitOneRankPartFlag(flagAddr, blockNum, value); + } + + // Wait for flagNum inter-card synchronization flags starting from startBlock for all ranks (memory B) + __aicore__ inline void WaitAllRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t startBlock, int64_t flagNum) + { + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr; + int waitRank; + for (auto r = 0; r < rankSize; ++r) { + waitRank = (rank + r) % rankSize; // Offset reading of rank flags to prevent performance impact from + // concurrent copying by multiple cores + flagAddr = GetOuterFlagAddr(waitRank, startBlock); + 
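// WaitOneRankPartFlag below spins until all flagNum outer flags of waitRank
// (starting at startBlock) carry the current magic and have reached eventID, so the
// surrounding loop only returns once every rank has published its flags for this
// round.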
WaitOneRankPartFlag(flagAddr, flagNum, value); + } + } + + // Check flagNum inter-card synchronization flags starting from startBlock for all ranks (memory B) + __aicore__ inline bool CheckAllRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t startBlock, + int64_t flagNum) + { + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr; + int waitRank; + for (auto r = 0; r < rankSize; ++r) { + waitRank = (rank + r) % rankSize; // Offset reading of rank flags to prevent performance impact from + // concurrent copying by multiple cores + flagAddr = GetOuterFlagAddr(waitRank, startBlock); + if (!CheckOneRankPartFlag(flagAddr, flagNum, value)) { + return false; + } + } + return true; + } + + // Wait for all inter-card synchronization flags for all ranks, full rank synchronization (memory B) + __aicore__ inline void WaitAllRankOuterFlag(int32_t magic, int32_t eventID) + { + WaitAllRankPartOuterFlag(magic, eventID, 0, blockNum); + } + + // Check all inter-card synchronization flags for all ranks, full rank synchronization (memory B) + __aicore__ inline bool CheckAllRankOuterFlag(int32_t magic, int32_t eventID) + { + return CheckAllRankPartOuterFlag(magic, eventID, 0, blockNum); + } + + // Low-level interface, set synchronization flag + __aicore__ inline void SetFlag(__gm__ int64_t *setAddr, int64_t setValue) + { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + GlobalTensor globalSet; + globalSet.SetGlobalBuffer(setAddr, FLAG_UNIT_INT_NUM); + LocalTensor localSet = tBuf.GetWithOffset(1, 0); + localSet.SetValue(0, setValue); + + // Copy global synchronization flag to local + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); // Wait for SetValue to complete + DataCopy(globalSet, localSet, FLAG_UNIT_INT_NUM); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); // Wait for UB->GM to complete + } + + // Low-level interface, wait for synchronization flag + __aicore__ inline void WaitFlag(__gm__ int64_t *waitAddr, int64_t waitValue) + { + WaitOneRankPartFlag(waitAddr, 1, waitValue); + } + + // Read a flag, return an immediate number + __aicore__ inline int64_t GetFlag(__gm__ int64_t *waitAddr) + { + GlobalTensor globalWait; + globalWait.SetGlobalBuffer(waitAddr, FLAG_UNIT_INT_NUM); + LocalTensor localWait = tBuf.GetWithOffset(1, 0); + // Copy global to local + DataCopy(localWait, globalWait, FLAG_UNIT_INT_NUM); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); // Wait for GM->UB + + int64_t res = localWait.GetValue(0); + return res; + } + + // Get multiple consecutive synchronization flags within a single card + __aicore__ inline void WaitOneRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t waitRank, + int64_t startBlock, int64_t flagNum) + { + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr; + flagAddr = GetOuterFlagAddr(waitRank, startBlock); + WaitOneRankPartFlag(flagAddr, flagNum, value); + } + + // Get synchronization flag within a single card (memory A) + __aicore__ inline int64_t GetInnerFlag(int64_t waitRank, int64_t waitBlock) + { + return GetFlag((__gm__ int64_t *)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM); + } + + __aicore__ inline int64_t GetOuterFlag(int64_t waitRank, int64_t waitBlock) + { + return GetFlag((__gm__ int64_t *)(shareAddrs[waitRank]) + segmentCount + waitBlock * FLAG_UNIT_INT_NUM); + } + + // In the rank Chunk Flag area, return success if the destRank chunk Flag value is 0, otherwise 
fail + __aicore__ inline int64_t GetChunkFlag(int64_t rank, int64_t destRank, int64_t magic, int64_t timeout) + { + int64_t value = MergeMagicWithValue(magic, 0); + int64_t status = GetChunkFlagValue( + (__gm__ int64_t *)(shareAddrs[rank]) + IPC_CHUNK_FLAG + destRank * FLAG_UNIT_INT_NUM, value, timeout); + return status; + } + + // Set the destRank chunk Flag value in the rank Chunk Flag area to value + __aicore__ inline void SetChunkFlag(int64_t rank, int64_t destRank, int64_t magic, int64_t eventId) + { + int64_t value = MergeMagicWithValue(magic, eventId); + SetFlag((__gm__ int64_t *)(shareAddrs[rank]) + IPC_CHUNK_FLAG + destRank * FLAG_UNIT_INT_NUM, value); + } + + __aicore__ inline int64_t GetChunkRecvLen(int64_t rank, int64_t destRank, int64_t magic, int64_t timeout) + { + int64_t len = + GetChunkFlagValue((__gm__ int64_t *)(shareAddrs[rank]) + IPC_CHUNK_FLAG + destRank * FLAG_UNIT_INT_NUM, 0, + timeout, true, magic); + return len; + } + +private: + __aicore__ inline int64_t MergeMagicWithValue(int32_t magic, int32_t value) + { + // Merge magic as the high bits and eventID as the low bits into a value for comparison + return (static_cast(static_cast(magic)) << MAGIC_OFFSET) | static_cast(value); + } + + __aicore__ inline __gm__ int64_t *GetInnerFlagAddr(int64_t flagRank, int64_t flagBlock) + { + return (__gm__ int64_t *)(shareAddrs[flagRank]) + flagBlock * FLAG_UNIT_INT_NUM; + } + + __aicore__ inline __gm__ int64_t *GetOuterFlagAddr(int64_t flagRank, int64_t flagBlock) + { + return (__gm__ int64_t *)(shareAddrs[flagRank]) + segmentCount + flagBlock * FLAG_UNIT_INT_NUM; + } + + // Wait for a part of synchronization flags within a rank + __aicore__ inline void WaitOneRankPartFlag(__gm__ int64_t *waitAddr, int64_t flagNum, int64_t checkValue) + { + GlobalTensor globalWait; + globalWait.SetGlobalBuffer(waitAddr, flagNum * FLAG_UNIT_INT_NUM); + LocalTensor localWait = tBuf.GetWithOffset(flagNum * FLAG_UNIT_INT_NUM, 0); + bool isSync = true; + int64_t checkedFlagNum = 0; + do { + // Copy global synchronization flags to local + DataCopy(localWait, globalWait[checkedFlagNum * FLAG_UNIT_INT_NUM], + (flagNum - checkedFlagNum) * FLAG_UNIT_INT_NUM); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); // Wait for GM->UB + + // Check if the synchronization flags are equal to checkValue + isSync = true; + int64_t remainToCheck = flagNum - checkedFlagNum; + for (auto i = 0; i < remainToCheck; ++i) { + // Continue waiting if any core has not reached the checkValue phase + int64_t v = localWait.GetValue(i * FLAG_UNIT_INT_NUM); + if ((v & MAGIC_MASK) != (checkValue & MAGIC_MASK) || v < checkValue) { + isSync = false; + checkedFlagNum += i; + break; + } + } + } while (!isSync); + } + + // Wait for all synchronization flags within a rank + __aicore__ inline void WaitOneRankAllFlag(__gm__ int64_t *waitAddr, int64_t checkValue) + { + WaitOneRankPartFlag(waitAddr, blockNum, checkValue); + } + + // Check partial synchronization flags within a rank, copy only once + __aicore__ inline bool CheckOneRankPartFlag(__gm__ int64_t *waitAddr, int64_t flagNum, int64_t checkValue) + { + GlobalTensor globalWait; + globalWait.SetGlobalBuffer(waitAddr, flagNum * FLAG_UNIT_INT_NUM); + LocalTensor localWait = tBuf.GetWithOffset(flagNum * FLAG_UNIT_INT_NUM, 0); + // Copy global synchronization flags to local + DataCopy(localWait, globalWait, flagNum * FLAG_UNIT_INT_NUM); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); // Wait for GM->UB + // Check if the synchronization flags are equal to checkValue 
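// A flag is accepted only when its high 32 bits (the magic) equal checkValue's magic
// and the full 64-bit value has reached checkValue. E.g. for
// checkValue = (7LL << MAGIC_OFFSET) | 3, a stored (7LL << MAGIC_OFFSET) | 5 passes
// (same magic, value already past 3), while (6LL << MAGIC_OFFSET) | 9 (different
// magic) or (7LL << MAGIC_OFFSET) | 2 (not yet reached) fails the check below.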
+ bool isSync = true; + for (auto i = 0; i < flagNum; ++i) { + // Continue waiting if any core has not reached the checkValue phase + int64_t v = localWait.GetValue(i * FLAG_UNIT_INT_NUM); + if ((v & MAGIC_MASK) != (checkValue & MAGIC_MASK) || v < checkValue) { + isSync = false; + break; + } + } + return isSync; + } + + __aicore__ inline int64_t GetChunkFlagValue(__gm__ int64_t *waitAddr, int64_t checkValue, int64_t timeout, + bool checkNonZero = false, int64_t magic = 0) + { + GlobalTensor globalWait; + globalWait.SetGlobalBuffer(waitAddr, FLAG_UNIT_INT_NUM); + LocalTensor localWait = tBuf.GetWithOffset(FLAG_UNIT_INT_NUM, 0); + bool isSync = true; + + int64_t waitTimes = 0; + int64_t v = 0; + + do { + // Copy global sync flag to local + DataCopy(localWait, globalWait[0], FLAG_UNIT_INT_NUM); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); // Wait for GM->UB + + isSync = true; + v = localWait.GetValue(0); + if (checkNonZero) { + // Non-zero check mode + if (((v & MAGIC_MASK) == (static_cast(magic) << MAGIC_OFFSET)) && (v & 0xFFFFFFFF)) { + return v & 0xFFFFFFFF; // Return lower 32 bits when non-zero + } + } else { + // Exact value check mode + if (v == checkValue) { + return WAIT_SUCCESS; + } + } + + isSync = false; + waitTimes++; + + if (timeout > INT64_MAX / MAX_WAIT_ROUND_UNIT || waitTimes >= (timeout * MAX_WAIT_ROUND_UNIT)) { + isSync = true; + return v; // Return the read flag value + } + } while (!isSync); + + return checkNonZero ? 0 : v; + } + + // Check all sync flags within a rank, copy only once + __aicore__ inline bool CheckOneRankAllFlag(__gm__ int64_t *waitAddr, int64_t checkValue) + { + return CheckOneRankPartFlag(waitAddr, blockNum, checkValue); + } + int rank; + int rankSize; + int blockIdx; + int blockNum; + GM_ADDR *shareAddrs; + int64_t segmentCount; // Length of a single sync flag segment (count in int64_t) + __gm__ int64_t *localSyncAddr; + __gm__ int64_t *basicSyncAddr; // Intra-card sync flag address for the current block + __gm__ int64_t *blockOuterSyncAddr; // Inter-card sync flag address for the current block + TBuf tBuf; +}; + +#endif // SYNC_COLLECTIVES_H diff --git a/csrc/deepep/ops2/scripts/help.info b/csrc/deepep/ops2/scripts/help.info new file mode 100644 index 00000000..de0069dc --- /dev/null +++ b/csrc/deepep/ops2/scripts/help.info @@ -0,0 +1 @@ + --install-path Install operator package to specific dir path diff --git a/csrc/deepep/ops2/scripts/install.sh b/csrc/deepep/ops2/scripts/install.sh new file mode 100755 index 00000000..e302e094 --- /dev/null +++ b/csrc/deepep/ops2/scripts/install.sh @@ -0,0 +1,317 @@ +#!/bin/bash + +vendor_name=customize +targetdir=/usr/local/Ascend/opp +target_custom=0 + +sourcedir=$PWD/packages +vendordir=vendors/$vendor_name + +QUIET="y" + +while true +do + case $1 in + --quiet) + QUIET="y" + shift + ;; + --install-path=*) + INSTALL_PATH=$(echo $1 | cut -d"=" -f2-) + INSTALL_PATH=${INSTALL_PATH%*/} + shift + ;; + --*) + shift + ;; + *) + break + ;; + esac +done + +log() { + cur_date=`date +"%Y-%m-%d %H:%M:%S"` + echo "[ops_custom] [$cur_date] "$1 +} + +if [ -n "${INSTALL_PATH}" ]; then + if [[ ! "${INSTALL_PATH}" = /* ]]; then + log "[ERROR] use absolute path for --install-path argument" + exit 1 + fi + if [ ! -d ${INSTALL_PATH} ]; then + mkdir ${INSTALL_PATH} >> /dev/null 2>&1 + if [ $? -ne 0 ]; then + log "[ERROR] create ${INSTALL_PATH} failed" + exit 1 + fi + fi + targetdir=${INSTALL_PATH} +elif [ -n "${ASCEND_CUSTOM_OPP_PATH}" ]; then + if [ ! 
-d ${ASCEND_CUSTOM_OPP_PATH} ]; then + mkdir -p ${ASCEND_CUSTOM_OPP_PATH} >> /dev/null 2>&1 + if [ $? -ne 0 ]; then + log "[ERROR] create ${ASCEND_CUSTOM_OPP_PATH} failed" + fi + fi + targetdir=${ASCEND_CUSTOM_OPP_PATH} +else + if [ "x${ASCEND_OPP_PATH}" == "x" ]; then + log "[ERROR] env ASCEND_OPP_PATH no exist" + exit 1 + fi + targetdir="${ASCEND_OPP_PATH}" +fi + +if [ ! -d $targetdir ];then + log "[ERROR] $targetdir no exist" + exit 1 +fi + +if [ ! -x $targetdir ] || [ ! -w $targetdir ] || [ ! -r $targetdir ];then + log "[WARNING] The directory $targetdir does not have sufficient permissions. \ + Please check and modify the folder permissions (e.g., using chmod), \ + or use the --install-path option to specify an installation path and \ + change the environment variable ASCEND_CUSTOM_OPP_PATH to the specified path." +fi + +upgrade() +{ + if [ ! -d ${sourcedir}/$vendordir/$1 ]; then + log "[INFO] no need to upgrade ops $1 files" + return 0 + fi + + if [ ! -d ${targetdir}/$vendordir/$1 ];then + log "[INFO] create ${targetdir}/$vendordir/$1." + mkdir -p ${targetdir}/$vendordir/$1 + if [ $? -ne 0 ];then + log "[ERROR] create ${targetdir}/$vendordir/$1 failed" + return 1 + fi + else + has_same_file=-1 + for file_a in ${sourcedir}/$vendordir/$1/*; do + file_b=${file_a##*/}; + if [ "ls ${targetdir}/$vendordir/$1" = "" ]; then + log "[INFO] ${targetdir}/$vendordir/$1 is empty !!" + return 1 + fi + grep -q $file_b <<<`ls ${targetdir}/$vendordir/$1`; + if [[ $? -eq 0 ]]; then + echo -n "${file_b} " + has_same_file=0 + fi + done + if [ 0 -eq $has_same_file ]; then + if test $QUIET = "n"; then + echo "[INFO]: has old version in ${targetdir}/$vendordir/$1, \ + you want to Overlay Installation , please enter:[o]; \ + or replace directory installation , please enter: [r]; \ + or not install , please enter:[n]." + + while true + do + read orn + if [ "$orn" = n ]; then + return 0 + elif [ "$orn" = m ]; then + break; + elif [ "$orn" = r ]; then + [ -n "${targetdir}/$vendordir/$1/" ] && rm -rf "${targetdir}/$vendordir/$1"/* + break; + else + log "[ERROR] input error, please input again!" + fi + done + fi + fi + log "[INFO] replace or merge old ops $1 files .g....." + fi + + log "copy new ops $1 files ......" + if [ -d ${targetdir}/$vendordir/$1/ ]; then + chmod -R +w "$targetdir/$vendordir/$1/" >/dev/null 2>&1 + fi + cp -rf ${sourcedir}/$vendordir/$1/* $targetdir/$vendordir/$1/ + if [ $? -ne 0 ];then + log "[ERROR] copy new $1 files failed" + return 1 + fi + + return 0 +} +upgrade_proto() +{ + if [ ! -f ${sourcedir}/$vendordir/custom.proto ]; then + log "[INFO] no need to upgrade custom.proto files" + return 0 + fi + if [ ! -d ${targetdir}/$vendordir/framework/caffe ];then + log "[INFO] create ${targetdir}/$vendordir/framework/caffe." + mkdir -p ${targetdir}/$vendordir/framework/caffe + if [ $? -ne 0 ];then + log "[ERROR] create ${targetdir}/$vendordir/framework/caffe failed" + return 1 + fi + else + if [ -f ${targetdir}/$vendordir/framework/caffe/custom.proto ]; then + # 有老版本,判断是否要覆盖式安装 + if test $QUIET = "n"; then + echo "[INFO] ${targetdir}/$vendordir/framework/caffe has old version"\ + "custom.proto file. Do you want to replace? [y/n] " + + while true + do + read yn + if [ "$yn" = n ]; then + return 0 + elif [ "$yn" = y ]; then + break; + else + log "[ERROR] input error, please input again!" + fi + done + fi + fi + log "[INFO] replace old caffe.proto files ......" 
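+ # Whether this is a fresh install, a quiet-mode run, or an overwrite confirmed above,
+ # fall through to the copy below: custom.proto is (re)placed under
+ # ${targetdir}/${vendordir}/framework/caffe.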
+ fi + chmod -R +w "$targetdir/$vendordir/framework/caffe/" >/dev/null 2>&1 + cp -rf ${sourcedir}/$vendordir/custom.proto ${targetdir}/$vendordir/framework/caffe/ + if [ $? -ne 0 ];then + log "[ERROR] copy new custom.proto failed" + return 1 + fi + log "[INFO] copy custom.proto success" + + return 0 +} + +upgrade_file() +{ + if [ ! -e ${sourcedir}/$vendordir/$1 ]; then + log "[INFO] no need to upgrade ops $1 file" + return 0 + fi + + log "copy new $1 files ......" + cp -f ${sourcedir}/$vendordir/$1 $targetdir/$vendordir/$1 + if [ $? -ne 0 ];then + log "[ERROR] copy new $1 file failed" + return 1 + fi + + return 0 +} + +delete_optiling_file() +{ + if [ ! -d ${targetdir}/vendors ];then + log "[INFO] $1 not exist, no need to uninstall" + return 0 + fi + sys_info=$(uname -m) + if [ ! -d ${sourcedir}/$vendordir/$1/ai_core/tbe/op_tiling/lib/linux/${sys_info} ];then + rm -rf ${sourcedir}/$vendordir/$1/ai_core/tbe/op_tiling/liboptiling.so + fi + return 0 +} + +log "[INFO] copy uninstall sh success" + +if [ ! -d ${targetdir}/vendors ];then + log "[INFO] create ${targetdir}/vendors." + mkdir -p ${targetdir}/vendors + if [ $? -ne 0 ];then + log "[ERROR] create ${targetdir}/vendors failed" + exit 1 + fi +fi +chmod u+w ${targetdir}/vendors + +log "[INFO] upgrade framework" +upgrade framework +if [ $? -ne 0 ];then + exit 1 +fi + +log "[INFO] upgrade op proto" +upgrade op_proto +if [ $? -ne 0 ];then + exit 1 +fi + +log "[INFO] upgrade op impl" +delete_optiling_file op_impl +upgrade op_impl +if [ $? -ne 0 ];then + exit 1 +fi + +log "[INFO] upgrade op api" +upgrade op_api +if [ $? -ne 0 ];then + exit 1 +fi + +log "[INFO] upgrade version.info" +upgrade_file version.info +if [ $? -ne 0 ];then + exit 1 +fi + +upgrade_proto +if [ $? -ne 0 ];then + exit 1 +fi + +# set the set_env.bash +if [ -n "${INSTALL_PATH}" ] && [ -d ${INSTALL_PATH} ]; then + _ASCEND_CUSTOM_OPP_PATH=${targetdir}/${vendordir} + bin_path="${_ASCEND_CUSTOM_OPP_PATH}/bin" + set_env_variable="#!/bin/bash\nexport ASCEND_CUSTOM_OPP_PATH=${_ASCEND_CUSTOM_OPP_PATH}:\${ASCEND_CUSTOM_OPP_PATH}\nexport LD_LIBRARY_PATH=${_ASCEND_CUSTOM_OPP_PATH}/op_api/lib/:\${LD_LIBRARY_PATH}" + if [ ! -d ${bin_path} ]; then + mkdir -p ${bin_path} >> /dev/null 2>&1 + if [ $? -ne 0 ]; then + log "[ERROR] create ${bin_path} failed" + exit 1 + fi + fi + echo -e ${set_env_variable} > ${bin_path}/set_env.bash + if [ $? -ne 0 ]; then + log "[ERROR] write ASCEND_CUSTOM_OPP_PATH to set_env.bash failed" + exit 1 + else + log "[INFO] using requirements: when custom module install finished or before you run the custom module, \ + execute the command [ source ${bin_path}/set_env.bash ] to set the environment path" + fi +else + _ASCEND_CUSTOM_OPP_PATH=${targetdir}/${vendordir} + config_file=${targetdir}/vendors/config.ini + if [ ! -f ${config_file} ]; then + touch ${config_file} + chmod 640 ${config_file} + echo "load_priority=$vendor_name" > ${config_file} + if [ $? 
-ne 0 ];then + log "[ERROR] echo load_priority failed" + exit 1 + fi + else + found_vendors="$(grep -w "load_priority" "$config_file" | cut --only-delimited -d"=" -f2-)" + found_vendor=$(echo $found_vendors | sed "s/\<$vendor_name\>//g" | tr ',' ' ') + vendor=$(echo $found_vendor | tr -s ' ' ',') + if [ "$vendor" != "" ]; then + sed -i "/load_priority=$found_vendors/s@load_priority=$found_vendors@load_priority=$vendor_name,$vendor@g" "$config_file" + fi + fi + log "[INFO] using requirements: when custom module install finished or before you run the custom module, \ + execute the command [ export LD_LIBRARY_PATH=${_ASCEND_CUSTOM_OPP_PATH}/op_api/lib/:\${LD_LIBRARY_PATH} ] to set the environment path" +fi + +if [ -d ${targetdir}/$vendordir/op_impl/cpu/aicpu_kernel/impl/ ]; then + chmod -R 440 ${targetdir}/$vendordir/op_impl/cpu/aicpu_kernel/impl/* >/dev/null 2>&1 +fi + +echo "SUCCESS" +exit 0 diff --git a/csrc/deepep/ops2/scripts/upgrade.sh b/csrc/deepep/ops2/scripts/upgrade.sh new file mode 100755 index 00000000..38a59139 --- /dev/null +++ b/csrc/deepep/ops2/scripts/upgrade.sh @@ -0,0 +1,144 @@ +#!/bin/bash + +vendor_name=customize +targetdir=/usr/local/Ascend/opp +target_custom=0 + +sourcedir=$PWD/packages +vendordir=vendors/$vendor_name + +log() { + cur_date=`date +"%Y-%m-%d %H:%M:%S"` + echo "[ops_custom] [$cur_date] "$1 +} + +if [[ "x${ASCEND_OPP_PATH}" == "x" ]];then + log "[ERROR] env ASCEND_OPP_PATH no exist" + exit 1 +fi + +targetdir=${ASCEND_OPP_PATH} + +if [ ! -d $targetdir ];then + log "[ERROR] $targetdir no exist" + exit 1 +fi + +if [ ! -x $targetdir ] || [ ! -w $targetdir ] || [ ! -r $targetdir ];then + log "[WARNING] The directory $targetdir does not have sufficient permissions. \ + Please check and modify the folder permissions (e.g., using chmod), \ + or use the --install-path option to specify an installation path and \ + change the environment variable ASCEND_CUSTOM_OPP_PATH to the specified path." +fi + +upgrade() +{ + if [ ! -d ${sourcedir}/$vendordir/$1 ]; then + log "[INFO] no need to upgrade ops $1 files" + return 0 + fi + + if [ ! -d ${targetdir}/$vendordir/$1 ];then + log "[INFO] create ${targetdir}/$vendordir/$1." + mkdir -p ${targetdir}/$vendordir/$1 + if [ $? -ne 0 ];then + log "[ERROR] create ${targetdir}/$vendordir/$1 failed" + return 1 + fi + else + vendor_installed_dir=$(ls "$targetdir/vendors" 2> /dev/null) + for i in $vendor_installed_dir;do + vendor_installed_file=$(ls "$vendor_installed_dir/$vendor_name/$i" 2> /dev/null) + if [ "$i" = "$vendor_name" ] && [ "$vendor_installed_file" != "" ]; then + echo "[INFO]: $vendor_name custom opp package has been installed on the path $vendor_installed_dir, \ + you want to Overlay Installation , please enter:[o]; \ + or replace directory installation , please enter: [r]; \ + or not install , please enter:[n]." + fi + while true + do + read mrn + if [ "$mrn" = m ]; then + break + elif [ "$mrn" = r ]; then + [ -n "$vendor_installed_file" ] && rm -rf "$vendor_installed_file" + break + elif [ "$mrn" = n ]; then + return 0 + else + log "[WARNING]: Input error, please input m or r or n to choose!" + fi + done + done + log "[INFO] replace old ops $1 files ......" + fi + + log "copy new ops $1 files ......" + cp -rf ${sourcedir}/$vendordir/$1/* $targetdir/$vendordir/$1/ + if [ $? -ne 0 ];then + log "[ERROR] copy new $1 files failed" + return 1 + fi + + return 0 +} + +upgrade_file() +{ + if [ ! -e ${sourcedir}/$vendordir/$1 ]; then + log "[INFO] no need to upgrade ops $1 file" + return 0 + fi + + log "copy new $1 files ......" 
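+ # Overwrite the single packaged file (in this script only version.info is upgraded this
+ # way) in the installed vendor directory.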
+ cp -f ${sourcedir}/$vendordir/$1 $targetdir/$vendordir/$1 + if [ $? -ne 0 ];then + log "[ERROR] copy new $1 file failed" + return 1 + fi + + return 0 +} + +log "[INFO] copy uninstall sh success" + +log "[INFO] upgrade framework" +upgrade framework +if [ $? -ne 0 ];then + exit 1 +fi + +log "[INFO] upgrade op proto" +upgrade op_proto +if [ $? -ne 0 ];then + exit 1 +fi + +log "[INFO] upgrade op impl" +upgrade op_impl +if [ $? -ne 0 ];then + exit 1 +fi + +log "[INFO] upgrade op api" +upgrade op_api +if [ $? -ne 0 ];then + exit 1 +fi + +log "[INFO] upgrade version.info" +upgrade_file version.info +if [ $? -ne 0 ];then + exit 1 +fi + +config_file=${targetdir}/vendors/config.ini +found_vendors="$(grep -w "load_priority" "$config_file" | cut --only-delimited -d"=" -f2-)" +found_vendor=$(echo $found_vendors | sed "s/\<$vendor_name\>//g" | tr ',' ' ') +vendor=$(echo $found_vendor | tr -s ' ' ',') +if [ "$vendor" != "" ]; then + sed -i "/load_priority=$found_vendors/s@load_priority=$found_vendors@load_priority=$vendor_name,$vendor@g" "$config_file" +fi + +echo "SUCCESS" +exit 0 diff --git a/csrc/deepep/ops2/utils/.DS_Store b/csrc/deepep/ops2/utils/.DS_Store new file mode 100644 index 00000000..cf38bc73 Binary files /dev/null and b/csrc/deepep/ops2/utils/.DS_Store differ diff --git a/csrc/deepep/ops2/utils/op_host/error_log.h b/csrc/deepep/ops2/utils/op_host/error_log.h new file mode 100644 index 00000000..d809a922 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_host/error_log.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * Description: create log implementation file + * Author: Han Jiahui + * Create: 2025-05-21 + * Note: + * History: 2025-05-21 create log implementation file + */ +#ifndef OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ +#define OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ + +#include +#include "toolchain/slog.h" + +#define OP_LOGI(opname, ...) +#define OP_LOGW(opname, ...) \ + printf("[WARN]" __VA_ARGS__); \ + printf("\n") +#define OP_LOGE_WITHOUT_REPORT(opname, ...) \ + printf("[ERRORx]" __VA_ARGS__); \ + printf("\n") +#define OP_LOGE(opname, ...) \ + printf("[ERROR]" __VA_ARGS__); \ + printf("\n") +#define OP_LOGD(opname, ...) + +namespace optiling { + +#define VECTOR_INNER_ERR_REPORT_TILIING(op_name, err_msg, ...) \ + do { \ + OP_LOGE_WITHOUT_REPORT(op_name, err_msg, ##__VA_ARGS__); \ + } while (0) + +#define OP_TILING_CHECK(cond, log_func, expr) \ + do { \ + if (cond) { \ + log_func; \ + expr; \ + } \ + } while (0) +} // namespace optiling + +#endif // OPS_BUILT_IN_OP_TILING_ERROR_LOG_H_ diff --git a/csrc/deepep/ops2/utils/op_kernel/.DS_Store b/csrc/deepep/ops2/utils/op_kernel/.DS_Store new file mode 100644 index 00000000..0382382b Binary files /dev/null and b/csrc/deepep/ops2/utils/op_kernel/.DS_Store differ diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/.DS_Store b/csrc/deepep/ops2/utils/op_kernel/operator/.DS_Store new file mode 100644 index 00000000..9d2b9c80 Binary files /dev/null and b/csrc/deepep/ops2/utils/op_kernel/operator/.DS_Store differ diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h b/csrc/deepep/ops2/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h new file mode 100644 index 00000000..d5217c2d --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h @@ -0,0 +1,802 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. 
All rights reserved. + * Description: add combine kernel implement + * Author: Chen Cheng + * Create: 2025-07-21 + * Note: + * History: 2025-07-21 add combine kernel implement + */ +#ifndef CAM_MOE_DISTRIBUTE_COMBINE_H +#define CAM_MOE_DISTRIBUTE_COMBINE_H + +#include "kernel_operator.h" +#include "kernel_tiling/kernel_tiling.h" +#include "../../../../../../op_kernel/fused_deep_moe_base.h" +#include "../../../../../../op_kernel/fused_deep_moe_tiling.h" + +namespace MoeDistributeCombineImpl { +constexpr uint8_t BUFFER_NUM = 2; // multi-buf +constexpr uint32_t STATE_OFFSET = 512; +constexpr uint32_t STATE_SIZE = 1024 * 1024; // 1M +constexpr uint32_t RANK_SIZE_ON_WIN_512 = 512 * 1024; +constexpr uint32_t RANK_SIZE_ON_WIN_256 = 256 * 1024; +constexpr uint32_t TP_RANK_SIZE_ON_WIN = 0; +constexpr uint32_t UB_ALIGN = 32; +constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; +constexpr uint8_t EP_DOMAIN = 0; +constexpr uint8_t TP_DOMAIN = 1; +constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024; +constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024; +constexpr uint16_t SEND_SYNC_EVENT_ID = 9; +constexpr uint16_t RECV_SYNC_EVENT_ID = 10; +constexpr uint32_t OPT_RANK_OFFSET = 512; + +template +__aicore__ inline void SyncFunc() +{ + int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); +} + +using namespace AscendC; + +struct CombineCalcInfo { + uint64_t expertPerSizeOnWin_; + uint32_t epRankId_; + uint32_t epWorldSize_; + uint32_t moeExpertPerRankNum_; + uint32_t sharedExpertRankNum_; + uint32_t axisH_; + uint32_t moeSendNum_; + bool isShardExpert_; + GM_ADDR epSendCount_; + __gm__ HcclOpResParam *epWinContext_; + uint64_t winDataSizeOffset_; +}; + +template +class CamMoeDistributeCombine +{ +public: + __aicore__ inline CamMoeDistributeCombine(){}; + __aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, + GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, + const FusedDeepMoeTilingData *tilingData); + __aicore__ inline void Process(); + __aicore__ inline void AllToAllSend(); + __aicore__ inline void ReducePermute(); + + __aicore__ inline CombineCalcInfo &GetCalcInfo() + { + return calcInfo_; + } + + __aicore__ inline void TPipeSet(AscendC::TPipe *pipe) + { + tpipe_ = pipe; + } + +private: + __aicore__ inline void InitStatusTargetSum(); + __aicore__ inline void AlltoAllBuffInit(); + __aicore__ inline void ReduceScatterTrans(); + __aicore__ inline void SetWaitTpStatusAndDisPatch(); + __aicore__ inline void CustomAdd(LocalTensor &dst, LocalTensor &src0, + LocalTensor &src1, uint32_t dataCnt); + __aicore__ inline void ExpertAlltoAllDispatchInnerCopyAdd(uint32_t tokenNumLoop, uint32_t srcStartTokenIdx, + uint32_t ep, uint32_t expertIdx); + __aicore__ inline void ExpertAlltoAllDispatchCopyAdd(); + __aicore__ inline void LocalWindowCopy(); + __aicore__ inline void BuffInit(); + __aicore__ inline void SplitCoreCal(); + __aicore__ inline void SetStatus(); + __aicore__ inline void WaitDispatch(); + __aicore__ GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t domain, const uint8_t expertLocalId = 0U) + { + if (domain == EP_DOMAIN) { + return (GM_ADDR)((epRankId_ == rankId) + ? epWinContext_->localWindowsIn + : ((HcclRankRelationResV2 *)(epWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsIn) + + winDataSizeOffset_ + expertLocalId * expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET; + } else { + return (GM_ADDR)((tpRankId_ == rankId) + ? 
tpWinContext_->localWindowsIn + : ((HcclRankRelationResV2 *)(tpWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsIn) + + winDataSizeOffset_ + rankId * OPT_RANK_OFFSET; + } + } + + __aicore__ GM_ADDR GetWinStateAddrByRankId(const int32_t rankId, const uint8_t domain) + { + if (domain == EP_DOMAIN) { + return (GM_ADDR)((epRankId_ == rankId) + ? epWinContext_->localWindowsExp + : ((HcclRankRelationResV2 *)(epWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsExp) + + dataState_ * WIN_STATE_OFFSET; + } else { + return (GM_ADDR)((tpRankId_ == rankId) + ? tpWinContext_->localWindowsExp + : ((HcclRankRelationResV2 *)(tpWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsExp) + + dataState_ * WIN_STATE_OFFSET; + } + } + + __aicore__ inline uint32_t MIN(uint32_t x, uint32_t y) + { + return (x < y) ? x : y; + } + + __aicore__ static void DoCombineRecv(void *ptr) + { + auto *combiner = (CamMoeDistributeCombine *)ptr; + combiner->ReducePermute(); + } + + TPipe *tpipe_{nullptr}; + GlobalTensor expandXGM_; + GlobalTensor expertIdsGM_; + GlobalTensor expandIdxGM_; + GlobalTensor epSendCountGM_; + GlobalTensor tpSendCountGM_; + GlobalTensor expandScalesGM_; + GlobalTensor expandOutGlobal_; + GlobalTensor rankWindow_; + GlobalTensor rankStates_; + GlobalTensor epStatusSpaceGlobalTensor_; + GlobalTensor tpStatusSpaceGlobalTensor_; + GlobalTensor tpRankWindow_; + GlobalTensor rowTmpGlobal_; + GM_ADDR workspaceGM_; + GM_ADDR epWindowGM_; + GM_ADDR epStatusSpaceGm_; + GM_ADDR tpWindowGM_; + GM_ADDR tpStatusSpaceGm_; + GM_ADDR stateGM_; + + LocalTensor winTpSendCountTensor_; + LocalTensor gmTpSendCountTensor_; + LocalTensor outTensor_; + LocalTensor winTpSendCountFloatTensor_; + LocalTensor gmTpSendCountFloatTensor_; + LocalTensor epSendCountLocal_; + + CombineCalcInfo calcInfo_; + uint32_t axisBS_{0}; + uint32_t axisMaxBs_{0}; + uint32_t axisH_{0}; + uint32_t axisK_{0}; + uint32_t aivNum_{0}; + uint32_t epWorldSize_{0}; + uint32_t tpWorldSize_{0}; + uint32_t epRankId_{0}; + uint32_t tpRankId_{0}; + uint32_t coreIdx_{0}; // aiv id + uint32_t sharedExpertRankNum_{0}; + uint32_t moeExpertNum_{0}; + uint32_t moeExpertPerRankNum_{0}; + uint32_t moeSendNum_{0}; // moeExpertPerRankNum_ * epWorldSize_ + uint32_t tpScatterNum_{0}; + uint32_t firstTpTokenEndIdx_{0}; + uint32_t firstTpTokenEndOffset_{0}; + uint32_t endTok_{0}; + __gm__ HcclOpResParam *epWinContext_{nullptr}; + __gm__ HcclOpResParam *tpWinContext_{nullptr}; + uint32_t epDataOffsetOnWin_{0}; + uint32_t tpDataOffsetOnWin_{0}; + uint32_t epStateOffsetOnWin_{0}; + uint32_t tpStateOffsetOnWin_{0}; + uint32_t axisHFloatSize_{0}; + uint32_t axisHExpandXTypeSize_{0}; + uint32_t bsKNum_{0}; + uint32_t startRankId_{0}; + uint32_t endRankId_{0}; + uint32_t sendRankNum_{0}; + uint32_t ubSize_{0}; + uint32_t dataState_{0}; + uint32_t stateOffset_{0}; + uint64_t winDataSizeOffset_{0}; + uint64_t expertPerSizeOnWin_{0}; + uint64_t totalWinSize_{0}; + TQueBind moeQueue_; + TQue moeSumQueue_; + TQueBind gmTpSendCountQueue_; + TQue gmTpSendCountInQueue_; + TQue winTpSendCountInQueue_; + TQue xOutQueue_; + TBuf<> readStateBuf_; + TBuf<> expertIdsBuf_; + TBuf<> expandScalesBuf_; + TBuf<> rowTmpFloatBuf_; + TBuf<> sumFloatBuf_; + TBuf<> mulBuf_; + TBuf<> sendCountBuf_; + TBuf<> indexCountsBuf_; + TBuf<> winTpSendCountFloatBuf_; + TBuf<> gmTpSendCountFloatBuf_; + TBuf<> tokenBuf_; + TBuf<> statusBuf_; + TBuf<> gatherMaskOutBuf_; // gather mask output buf + TBuf<> gatherTmpBuf_; + TBuf<> statusSumOutBuf_; + float sumTarget_{0.0}; + int32_t epStateValue_; + 
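+ // sumTarget_ and epStateValue_ flip between 1.0f (bit pattern 0x3F800000) and 0 on
+ // alternate invocations (see InitStatusTargetSum below), which appears to let the wait
+ // loops tell flags written by the current call apart from stale flags of the previous one.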
bool isShardExpert_{false}; +}; + +template +__aicore__ inline void CamMoeDistributeCombine::Init( + GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales, + GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const FusedDeepMoeTilingData *tilingData) +{ + tpipe_ = pipe; + coreIdx_ = GetBlockIdx(); + epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; + auto contextGM0 = AscendC::GetHcclContext(); + epWinContext_ = (__gm__ HcclOpResParam *)contextGM0; + GlobalTensor selfDataStatusTensor; + GM_ADDR statusDataSpaceGm = (GM_ADDR)epWinContext_->localWindowsExp; + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfDataStatusTensor[coreIdx_ * UB_ALIGN]); + __asm__ __volatile__(""); + dataState_ = selfDataStatusTensor(coreIdx_ * UB_ALIGN); + selfDataStatusTensor(coreIdx_ * UB_ALIGN) = 1 - dataState_; + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfDataStatusTensor[coreIdx_ * UB_ALIGN]); + __asm__ __volatile__(""); + pipe_barrier(PIPE_ALL); + + workspaceGM_ = workspaceGM; + expandXGM_.SetGlobalBuffer((__gm__ ExpandXType *)expandX); + expertIdsGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expertIds); + expandIdxGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx); + epSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)epSendCount); + expandScalesGM_.SetGlobalBuffer((__gm__ float *)scales); + expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut); + axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs; + axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k; + aivNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum; + ubSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.totalUbSize; + sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; + moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; + moeExpertPerRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + epWorldSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; + axisMaxBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs / epWorldSize_; + moeSendNum_ = epWorldSize_ * moeExpertPerRankNum_; + tpWorldSize_ = 1; + tpRankId_ = 0; + totalWinSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.totalWinSize; + stateOffset_ = (moeSendNum_ > 512) ? 
(STATE_OFFSET / 2) : STATE_OFFSET; + expertPerSizeOnWin_ = + static_cast(axisMaxBs_) * static_cast(axisH_) * static_cast(sizeof(ExpandXType)); + winDataSizeOffset_ = static_cast(dataState_) * static_cast(moeSendNum_) * expertPerSizeOnWin_; + epWindowGM_ = GetWinAddrByRankId(epRankId_, EP_DOMAIN); + epStatusSpaceGm_ = GetWinStateAddrByRankId(epRankId_, EP_DOMAIN); + epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)epStatusSpaceGm_); + epDataOffsetOnWin_ = epRankId_ * moeExpertPerRankNum_ * static_cast(expertPerSizeOnWin_); + epStateOffsetOnWin_ = epRankId_ * stateOffset_; + isShardExpert_ = (epRankId_ < sharedExpertRankNum_); + axisHFloatSize_ = axisH_ * sizeof(float); + axisHExpandXTypeSize_ = axisH_ * sizeof(ExpandXType); + bsKNum_ = axisBS_ * axisK_; + + if constexpr (IsNeedReduceScatter) { + tpSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)tpSendCount); + tpWindowGM_ = GetWinAddrByRankId(tpRankId_, TP_DOMAIN); + tpStatusSpaceGm_ = GetWinStateAddrByRankId(tpRankId_, TP_DOMAIN); + tpStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)tpStatusSpaceGm_); + tpDataOffsetOnWin_ = tpRankId_ * TP_RANK_SIZE_ON_WIN; + tpStateOffsetOnWin_ = tpRankId_ * stateOffset_; + uint32_t tpScatterRankWinOffset = (tpRankId_ == 0) ? TP_RANK_SIZE_ON_WIN : 0; + GM_ADDR rankGM = tpWindowGM_ + tpScatterRankWinOffset; + tpRankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM); + } + + InitStatusTargetSum(); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + coreIdx_ = get_block_idx(); + } + SplitCoreCal(); + + calcInfo_.epRankId_ = epRankId_; + calcInfo_.epWorldSize_ = epWorldSize_; + calcInfo_.expertPerSizeOnWin_ = expertPerSizeOnWin_; + calcInfo_.moeExpertPerRankNum_ = moeExpertPerRankNum_; + calcInfo_.sharedExpertRankNum_ = sharedExpertRankNum_; + calcInfo_.axisH_ = axisH_; + calcInfo_.moeSendNum_ = moeSendNum_; + calcInfo_.isShardExpert_ = isShardExpert_; + calcInfo_.epSendCount_ = epSendCount; + calcInfo_.epWinContext_ = epWinContext_; + calcInfo_.winDataSizeOffset_ = winDataSizeOffset_; +} + +template +__aicore__ inline void CamMoeDistributeCombine::InitStatusTargetSum() +{ + // ep state + GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(epStatusSpaceGm_ + SELF_STATE_OFFSET)); + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfStatusTensor[coreIdx_ * UB_ALIGN]); + __asm__ __volatile__(""); + int32_t state = selfStatusTensor(coreIdx_ * UB_ALIGN); + if (state == 0) { + sumTarget_ = static_cast(1.0); + selfStatusTensor(coreIdx_ * UB_ALIGN) = 0x3F800000; // 1.0f + epStateValue_ = 0x3F800000; // 1.0f + } else { + sumTarget_ = static_cast(0.0); + selfStatusTensor(coreIdx_ * UB_ALIGN) = 0; + epStateValue_ = 0; + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfStatusTensor[coreIdx_ * UB_ALIGN]); + __asm__ __volatile__(""); +} + +template +__aicore__ inline void CamMoeDistributeCombine::BuffInit() +{ + tpipe_->Reset(); + tpipe_->InitBuffer(readStateBuf_, UB_ALIGN); // 32 + uint32_t sendNumAlign = Ceil(moeSendNum_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN; + tpipe_->InitBuffer(sendCountBuf_, sendNumAlign); // epWorldSize_ * moeExpertPerRankNum_ * 4 + if constexpr (IsNeedReduceScatter) { + tpipe_->InitBuffer(winTpSendCountInQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 28K + tpipe_->InitBuffer(gmTpSendCountInQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 28K + tpipe_->InitBuffer(xOutQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 28K + if constexpr (AscendC::IsSameType::value) { + tpipe_->InitBuffer(winTpSendCountFloatBuf_, axisHFloatSize_); 
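+ // Two float staging buffers of axisH_ elements each: CustomAdd() up-casts both inputs
+ // to float, adds them, then rounds back to ExpandXType, presumably to avoid
+ // low-precision accumulation.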
+ tpipe_->InitBuffer(gmTpSendCountFloatBuf_, axisHFloatSize_); + winTpSendCountFloatTensor_ = winTpSendCountFloatBuf_.Get(); + gmTpSendCountFloatTensor_ = gmTpSendCountFloatBuf_.Get(); + } + } else { + tpipe_->InitBuffer(gmTpSendCountQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 28K + } + epSendCountLocal_ = sendCountBuf_.Get(); +} + +template +__aicore__ inline void CamMoeDistributeCombine::AlltoAllBuffInit() +{ + tpipe_->Reset(); + tpipe_->InitBuffer(readStateBuf_, UB_ALIGN); + tpipe_->InitBuffer(statusBuf_, sendRankNum_ * UB_ALIGN); + tpipe_->InitBuffer(expertIdsBuf_, axisBS_ * axisK_ * sizeof(int32_t)); + tpipe_->InitBuffer(expandScalesBuf_, axisBS_ * axisK_ * sizeof(float)); + tpipe_->InitBuffer(tokenBuf_, axisH_ * sizeof(ExpandXType)); + tpipe_->InitBuffer(rowTmpFloatBuf_, axisHFloatSize_); // 7168 * 4 = 28672 + tpipe_->InitBuffer(mulBuf_, axisHFloatSize_); // 7168 * 4 = 28672 + tpipe_->InitBuffer(sumFloatBuf_, axisHFloatSize_); // 7168 * 4 = 28672 + tpipe_->InitBuffer(indexCountsBuf_, axisBS_ * axisK_ * sizeof(int32_t)); + tpipe_->InitBuffer(moeSumQueue_, BUFFER_NUM, axisHExpandXTypeSize_); + tpipe_->InitBuffer(gatherMaskOutBuf_, epWorldSize_ * sizeof(float)); + tpipe_->InitBuffer(gatherTmpBuf_, sizeof(uint32_t)); // 4 + tpipe_->InitBuffer(statusSumOutBuf_, sizeof(float)); // 4 +} + +template +__aicore__ inline void CamMoeDistributeCombine::SplitCoreCal() +{ + sendRankNum_ = epWorldSize_ / aivNum_; + uint32_t remainderRankNum = epWorldSize_ % aivNum_; + startRankId_ = sendRankNum_ * coreIdx_; + if (coreIdx_ < remainderRankNum) { + sendRankNum_++; + startRankId_ += coreIdx_; + } else { + startRankId_ += remainderRankNum; + } + endRankId_ = startRankId_ + sendRankNum_; +} + +template +__aicore__ inline void CamMoeDistributeCombine::ReduceScatterTrans() +{ + __asm__ __volatile__(""); + DataCacheCleanAndInvalid(tpSendCountGM_[tpRankId_]); + __asm__ __volatile__(""); + uint32_t offset = tpSendCountGM_.GetValue(tpRankId_) * axisH_; + GlobalTensor dataCopyInGM = expandXGM_[offset]; + GM_ADDR rankGM = GetWinAddrByRankId(1 - tpRankId_, TP_DOMAIN) + tpDataOffsetOnWin_; + rankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM); + uint32_t copyStartIdx = 0; + if (startRankId_ > 0) { + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + epSendCountGM_[epWorldSize_ + startRankId_ - 1]); + __asm__ __volatile__(""); + copyStartIdx = epSendCountGM_.GetValue(epWorldSize_ + startRankId_ - 1); + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + epSendCountGM_[epWorldSize_ + endRankId_ - 1]); + __asm__ __volatile__(""); + uint32_t copyEndIdx = epSendCountGM_.GetValue(epWorldSize_ + endRankId_ - 1); + LocalTensor tmpUb; + for (uint32_t tokenNumIdx = copyStartIdx; tokenNumIdx < copyEndIdx; tokenNumIdx++) { + tmpUb = moeQueue_.AllocTensor(); + DataCopy(tmpUb, dataCopyInGM[tokenNumIdx * axisH_], axisH_); + moeQueue_.EnQue(tmpUb); + tmpUb = moeQueue_.DeQue(); + DataCopy(rankWindow_[tokenNumIdx * axisH_], tmpUb, axisH_); + moeQueue_.FreeTensor(tmpUb); + } +} + +// 46 -> gm -> ub syncall win->gm add -> alltoall +// 2 -> win wait syncall gm -> ub win ->gm add -> alltoall +template +__aicore__ inline void CamMoeDistributeCombine::SetWaitTpStatusAndDisPatch() +{ + pipe_barrier(PIPE_ALL); + if (startRankId_ >= epWorldSize_) { + return; + } + if constexpr (IsNeedReduceScatter) { + uint32_t tpToRankId = 1 - tpRankId_; + pipe_barrier(PIPE_ALL); + LocalTensor statusFlagUb = readStateBuf_.Get(); + statusFlagUb(0) = sumTarget_; + SyncFunc(); + GlobalTensor tpWindowInstatusFp32Tensor_; + stateGM_ = 
GetWinStateAddrByRankId(tpToRankId, TP_DOMAIN) + coreIdx_ * stateOffset_; + tpWindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)stateGM_); + DataCopy(tpWindowInstatusFp32Tensor_, statusFlagUb, 8UL); + SyncFunc(); + LocalTensor statusFp32Tensor_ = readStateBuf_.Get(); + float sumOfFlag = static_cast(-1.0); + uint32_t statusRankOffset = coreIdx_ * stateOffset_ / sizeof(float); // tp = 2 + while (sumOfFlag != sumTarget_) { + DataCopy(statusFp32Tensor_, tpStatusSpaceGlobalTensor_[statusRankOffset], 8); + SyncFunc(); + sumOfFlag = statusFp32Tensor_.GetValue(0); + SyncFunc(); + } + } + // Copy win gm->ub add ->alltoall send + ExpertAlltoAllDispatchCopyAdd(); + SyncFunc(); +} + +template +__aicore__ inline void CamMoeDistributeCombine::ExpertAlltoAllDispatchCopyAdd() +{ + if (startRankId_ >= epWorldSize_) { + return; + } + uint32_t curRankExpertNum = 0; + DataCopyExtParams epSendCntParams; + if (isShardExpert_) { + curRankExpertNum = 1; + epSendCntParams = {1U, static_cast(epWorldSize_ * sizeof(uint32_t)), 0U, 0U, 0U}; + } else { + curRankExpertNum = moeExpertPerRankNum_; + epSendCntParams = {1U, static_cast(moeSendNum_ * sizeof(uint32_t)), 0U, 0U, 0U}; + } + DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + DataCopyPad(epSendCountLocal_, epSendCountGM_, epSendCntParams, copyPadParams); + SyncFunc(); + uint32_t preCount = 0; + uint32_t startTokenIdx = 0; + uint32_t curTokenNum = 0; + + for (uint32_t expertIdx = 0U; expertIdx < curRankExpertNum; expertIdx++) { + uint32_t sendEpCount = endRankId_ - startRankId_; + for (uint32_t i = 0; i < sendEpCount; ++i) { + uint32_t ep = startRankId_ + (i + epRankId_) % sendEpCount; + if ((ep > 0) || (expertIdx > 0U)) { + preCount = epSendCountLocal_.GetValue(expertIdx * epWorldSize_ + ep - 1); + } else { + preCount = 0; + } + curTokenNum = epSendCountLocal_.GetValue(expertIdx * epWorldSize_ + ep) - preCount; + if (curTokenNum == 0) { + continue; + } + startTokenIdx = preCount * axisH_; + ExpertAlltoAllDispatchInnerCopyAdd(curTokenNum, startTokenIdx, ep, expertIdx); + } + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::ExpertAlltoAllDispatchInnerCopyAdd( + uint32_t tokenNumLoop, uint32_t srcStartTokenIdx, uint32_t ep, uint32_t expertIdx) +{ + GM_ADDR rankGM = GetWinAddrByRankId(ep, EP_DOMAIN, expertIdx) + epDataOffsetOnWin_; + if ((isShardExpert_) && (ep < sharedExpertRankNum_)) { + rankGM = GetWinAddrByRankId(epRankId_, EP_DOMAIN, expertIdx) + ep * moeExpertPerRankNum_ * expertPerSizeOnWin_; + } + rankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM); + uint32_t dataCnt = axisH_; + for (uint32_t loopIdx = 0; loopIdx < tokenNumLoop; loopIdx++) { + if constexpr (IsNeedReduceScatter) { + gmTpSendCountTensor_ = gmTpSendCountInQueue_.AllocTensor(); + DataCopy(gmTpSendCountTensor_, expandXGM_[srcStartTokenIdx], dataCnt); + gmTpSendCountInQueue_.EnQue(gmTpSendCountTensor_); + + winTpSendCountTensor_ = winTpSendCountInQueue_.AllocTensor(); + DataCopy(winTpSendCountTensor_, tpRankWindow_[srcStartTokenIdx], dataCnt); + winTpSendCountInQueue_.EnQue(winTpSendCountTensor_); + + gmTpSendCountTensor_ = gmTpSendCountInQueue_.DeQue(); + winTpSendCountTensor_ = winTpSendCountInQueue_.DeQue(); + outTensor_ = xOutQueue_.AllocTensor(); + + CustomAdd(outTensor_, winTpSendCountTensor_, gmTpSendCountTensor_, dataCnt); + gmTpSendCountInQueue_.FreeTensor(gmTpSendCountTensor_); + winTpSendCountInQueue_.FreeTensor(winTpSendCountTensor_); + xOutQueue_.EnQue(outTensor_); + + outTensor_ = xOutQueue_.DeQue(); + DataCopy(rankWindow_[loopIdx * dataCnt], 
outTensor_, dataCnt); + xOutQueue_.FreeTensor(outTensor_); + } else { + gmTpSendCountTensor_ = gmTpSendCountQueue_.AllocTensor(); + DataCopy(gmTpSendCountTensor_, expandXGM_[srcStartTokenIdx], dataCnt); + ExpandXType val = expandXGM_[srcStartTokenIdx].GetValue(0); + gmTpSendCountQueue_.EnQue(gmTpSendCountTensor_); + gmTpSendCountTensor_ = gmTpSendCountQueue_.DeQue(); + DataCopy(rankWindow_[loopIdx * dataCnt], gmTpSendCountTensor_, dataCnt); + gmTpSendCountQueue_.FreeTensor(gmTpSendCountTensor_); + } + srcStartTokenIdx += dataCnt; + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::CustomAdd(LocalTensor &dst, + LocalTensor &src0, + LocalTensor &src1, + uint32_t dataCnt) +{ + if constexpr (AscendC::IsSameType::value) { + Cast(winTpSendCountFloatTensor_, src0, RoundMode::CAST_NONE, dataCnt); + Cast(gmTpSendCountFloatTensor_, src1, RoundMode::CAST_NONE, dataCnt); + pipe_barrier(PIPE_V); + Add(winTpSendCountFloatTensor_, winTpSendCountFloatTensor_, gmTpSendCountFloatTensor_, dataCnt); + pipe_barrier(PIPE_V); + Cast(dst, winTpSendCountFloatTensor_, RoundMode::CAST_ROUND, dataCnt); + } else { + Add(dst, src0, src1, dataCnt); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::SetStatus() +{ + pipe_barrier(PIPE_ALL); + if (startRankId_ >= epWorldSize_) { + return; + } + + LocalTensor statusFlagUb = readStateBuf_.Get(); + statusFlagUb.SetValue(0, epStateValue_); + SyncFunc(); + + for (uint32_t epIdx = startRankId_; epIdx < endRankId_; epIdx++) { + stateGM_ = GetWinStateAddrByRankId(epIdx, EP_DOMAIN) + epStateOffsetOnWin_; + rankStates_.SetGlobalBuffer((__gm__ int32_t *)stateGM_); + DataCopy(rankStates_, statusFlagUb, 8); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::WaitDispatch() +{ + if (startRankId_ < epWorldSize_) { + LocalTensor statusTensor = statusBuf_.Get(); + LocalTensor gatherMaskOutTensor = gatherMaskOutBuf_.Get(); + LocalTensor gatherTmpTensor = gatherTmpBuf_.Get(); + LocalTensor statusSumOutTensor = statusSumOutBuf_.Get(); + + gatherTmpTensor.SetValue(0, 1); + uint32_t mask = 1; // gatherMask + sum + uint64_t rsvdCnt = 0; + DataCopyParams intriParams{static_cast(sendRankNum_), 1, + static_cast((moeSendNum_ > 512) ? 
7 : 15), 0}; // srcStride is 15 blocks + float sumOfFlag = static_cast(-1.0); + float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5; + float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5; + SumParams sumParams{1, sendRankNum_, sendRankNum_}; + SyncFunc(); + while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) { + DataCopy(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)], + intriParams); + SyncFunc(); + GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask, + {1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt); + PipeBarrier(); + Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams); + SyncFunc(); + sumOfFlag = statusSumOutTensor.GetValue(0); + } + } + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID); + AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID); + } else { + SyncAll(); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::LocalWindowCopy() +{ + uint32_t beginIndex = 0; + uint32_t endIndex = 0; + uint32_t processLen = 0; + uint32_t tokenOffset = 0; + if (axisBS_ < aivNum_) { + uint32_t aivNumPerToken = aivNum_ / axisBS_; // axisBS_ < aivNum_ + if (coreIdx_ >= (axisBS_ * aivNumPerToken)) { + return; + } + uint32_t tokenIndex = coreIdx_ / aivNumPerToken; + processLen = ((axisH_ / UB_ALIGN) / aivNumPerToken) * UB_ALIGN; + tokenOffset = processLen * (coreIdx_ % aivNumPerToken); + if ((coreIdx_ % aivNumPerToken) == (aivNumPerToken - 1)) { + processLen = axisH_ - ((aivNumPerToken - 1) * processLen); + } + beginIndex = tokenIndex; + endIndex = beginIndex + 1U; + } else { + uint32_t tokenPerAivNum = axisBS_ / aivNum_; + uint32_t remainderToken = axisBS_ % aivNum_; + beginIndex = tokenPerAivNum * coreIdx_; + if (coreIdx_ < remainderToken) { + tokenPerAivNum++; + beginIndex = tokenPerAivNum * coreIdx_; + } else { + beginIndex += remainderToken; + } + endIndex = beginIndex + tokenPerAivNum; + processLen = axisH_; + } + LocalTensor expertIdsLocal = expertIdsBuf_.Get(); + LocalTensor expandScalesLocal = expandScalesBuf_.Get(); + + LocalTensor rowTmpFloatLocal = rowTmpFloatBuf_.Get(); + LocalTensor mulBufLocal = mulBuf_.Get(); + LocalTensor sumFloatBufLocal = sumFloatBuf_.Get(); + + LocalTensor indexCountsLocal = indexCountsBuf_.Get(); + const DataCopyExtParams bskParams = {1U, static_cast(bsKNum_ * sizeof(uint32_t)), 0U, 0U, 0U}; + const DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + const DataCopyPadExtParams copyPadFloatParams{false, 0U, 0U, 0U}; + + DataCopyPad(indexCountsLocal, expandIdxGM_, bskParams, copyPadParams); + DataCopyPad(expertIdsLocal, expertIdsGM_, bskParams, copyPadParams); + DataCopyPad(expandScalesLocal, expandScalesGM_, bskParams, copyPadFloatParams); + SyncFunc(); + + for (uint32_t tokenIndex = beginIndex; tokenIndex < endIndex; tokenIndex++) { + uint32_t index = tokenIndex * axisK_; + SyncFunc(); + Duplicate(sumFloatBufLocal, (float)0, axisH_); + for (uint32_t i = 0; i < axisK_; i++) { + int32_t moeExpert = expertIdsLocal.GetValue(index); + float scaleVal = expandScalesLocal.GetValue(index); + GM_ADDR wAddr = (__gm__ uint8_t *)(epWindowGM_) + + expertPerSizeOnWin_ * moeExpertPerRankNum_ * sharedExpertRankNum_ + + expertPerSizeOnWin_ * moeExpert + indexCountsLocal.GetValue(index) * axisHExpandXTypeSize_ + + tokenOffset * sizeof(ExpandXType); + rowTmpGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)wAddr); + ExpandXType val = rowTmpGlobal_.GetValue(0); + LocalTensor tmpUb = moeSumQueue_.AllocTensor(); + DataCopy(tmpUb, 
rowTmpGlobal_, processLen); + moeSumQueue_.EnQue(tmpUb); + tmpUb = moeSumQueue_.DeQue(); + Cast(rowTmpFloatLocal, tmpUb, AscendC::RoundMode::CAST_NONE, processLen); + AscendC::PipeBarrier(); + AscendC::Muls(mulBufLocal, rowTmpFloatLocal, scaleVal, processLen); + AscendC::PipeBarrier(); + AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, mulBufLocal, processLen); + index++; + moeSumQueue_.FreeTensor(tmpUb); + } + LocalTensor rowTmpLocal = tokenBuf_.Get(); + if (sharedExpertRankNum_ > 0U) { + uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_; + uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_; + uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ - + epRankId_ * axisBS_ / sharedExpertRankNum_; + __gm__ ExpandXType *shareAddr = + (__gm__ ExpandXType *)(epWindowGM_ + moeOnShareRank * expertPerSizeOnWin_ * moeExpertPerRankNum_) + + (tokenIndex - preCnt) * axisH_ + tokenOffset; + GlobalTensor shareTokGlobal; + shareTokGlobal.SetGlobalBuffer((__gm__ ExpandXType *)(shareAddr)); + SyncFunc(); + DataCopy(rowTmpLocal, shareTokGlobal, processLen); + SyncFunc(); + Cast(rowTmpFloatLocal, rowTmpLocal, AscendC::RoundMode::CAST_NONE, processLen); + AscendC::PipeBarrier(); + AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, rowTmpFloatLocal, processLen); + } + // 结果搬出 + AscendC::PipeBarrier(); + LocalTensor sumBufLocal = tokenBuf_.Get(); + Cast(sumBufLocal, sumFloatBufLocal, AscendC::RoundMode::CAST_RINT, processLen); + SyncFunc(); + DataCopy(expandOutGlobal_[tokenIndex * axisH_ + tokenOffset], sumBufLocal, processLen); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::Process() +{ + SyncAll(); + if constexpr (IsNeedReduceScatter) { + tpipe_->InitBuffer(moeQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 7168 * 2 * 2 = 28672 + ReduceScatterTrans(); + } + if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) { + BuffInit(); + SetWaitTpStatusAndDisPatch(); + } + AlltoAllBuffInit(); + SetStatus(); + WaitDispatch(); + LocalWindowCopy(); +} + +template +__aicore__ inline void CamMoeDistributeCombine::AllToAllSend() +{ + if constexpr (IsNeedReduceScatter) { + tpipe_->InitBuffer(moeQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 7168 * 2 * 2 = 28672 + ReduceScatterTrans(); + } + BuffInit(); + if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) { + SetWaitTpStatusAndDisPatch(); + AlltoAllBuffInit(); + } + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID); + AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID); + } else { + SyncAll(); + } + SetStatus(); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID); + } else { + SyncAll(); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::ReducePermute() +{ + AlltoAllBuffInit(); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID); + } else { + SyncAll(); + } + + WaitDispatch(); + LocalWindowCopy(); + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID); + } +} +} // namespace MoeDistributeCombineImpl + +#endif // CAM_MOE_DISTRIBUTE_COMBINE_IMPL_H diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h b/csrc/deepep/ops2/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h new file mode 100644 index 00000000..baea77c3 --- /dev/null 
+++ b/csrc/deepep/ops2/utils/op_kernel/operator/cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_dispatch.h @@ -0,0 +1,1050 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + * Description: CamMoeDistributeDispatch operator kernel function header file, for a3 + * Author: WANG Qiankun + * Create: 2025-05-29 + * Note: + * History: 2025-05-29 create CamMoeDistributeDispatch operator kernel function header file, for a3 + */ + +#ifndef CAM_MOE_DISTRIBUTE_DISPATCH_H +#define CAM_MOE_DISTRIBUTE_DISPATCH_H + +#include "kernel_operator.h" +#include "kernel_tiling/kernel_tiling.h" +#include "../../../../../../op_kernel/fused_deep_moe_base.h" +#include "../../../../../../op_kernel/fused_deep_moe_tiling.h" + +namespace MoeDistributeDispatchImpl { +constexpr uint8_t BUFFER_NUM = 2; +constexpr uint32_t STATE_OFFSET = 512; +constexpr uint32_t STATE_SIZE = 1024 * 1024; +constexpr uint32_t UB_ALIGN = 32; +constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; +constexpr uint8_t COMM_NUM = 2; +constexpr uint8_t COMM_EP_IDX = 0; +constexpr uint8_t COMM_TP_IDX = 1; +constexpr uint32_t GATHER_NUM_PER_TIME = 6; + +constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024; +constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024; +constexpr uint32_t TP_STATE_SIZE = 100 * 1024; +constexpr int CAM_MAX_RANK_SIZE = 384; +constexpr int64_t IPC_DATA_OFFSET = 2 * 1024 * 1024; +constexpr uint32_t OPT_RANK_OFFSET = 512; + +using countType = uint8_t; +constexpr uint32_t LOOP_OPT_MAX_BS = 64; +constexpr uint32_t LOOP_OPT_MAX_MOE_RANK = 256; +constexpr uint32_t TOPK_ELEM_COUNT_PER_BLOCK = UB_ALIGN / sizeof(int32_t); +constexpr uint32_t TABLE_ELEM_COUNT_PER_BLOCK = UB_ALIGN / sizeof(countType); +constexpr uint32_t INT32_NUM_PER_BLOCK = UB_ALIGN / sizeof(int32_t); + +template +__aicore__ inline void SyncFunc() +{ + int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); +} + +#define TemplateDispatchTypeClass \ + typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \ + bool IsNeedAllgater +#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater + +using namespace AscendC; +template +class CamMoeDistributeDispatch +{ +public: + __aicore__ inline CamMoeDistributeDispatch(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, + GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, + GM_ADDR sendCountsOut, GM_ADDR tpSendCountsOut, GM_ADDR workspaceGM, TPipe *pipe, + const FusedDeepMoeTilingData *tilingData); + __aicore__ inline void Process(); + +private: + __aicore__ inline void SendToSharedExpert(); + __aicore__ inline void SendToMoeExpert(); + __aicore__ inline void AlltoAllDispatch(); + __aicore__ inline void LocalWindowCopy(); + __aicore__ inline void QuantProcess(uint32_t expertIndex); + __aicore__ inline void LocalSharedExpertCopyWindow(uint32_t rankIndex, uint32_t tokenOffset, + uint32_t currendTokenIndex, uint32_t &dynamicScalesLocalIdx); + __aicore__ inline void SetStatus(); + __aicore__ inline void WaitDispatch(); + __aicore__ inline void GetCumSum(LocalTensor &inLocal, LocalTensor &outLocal, int32_t totalCount); + __aicore__ inline void CreateZeroTensor(LocalTensor &outTensor); + __aicore__ inline void AllGatherSetStatusAndWait(); + __aicore__ inline void ResetStatus(); + __aicore__ inline void QuantInit(GM_ADDR scales); + __aicore__ inline void 
AllgatherProcessOut(); + __aicore__ inline void UpdataMultiMoeTokenNumsOut(); + __aicore__ inline void UpdataTokenNumsOut(); + __aicore__ inline GM_ADDR GetWindAddrByRankId(uint8_t ctxIdx, const int32_t rankId) + { + uint32_t curRankId = ctxIdx == COMM_EP_IDX ? epRankId_ : tpRankId_; + if (curRankId == rankId) { + return (GM_ADDR)(winContext_[ctxIdx]->localWindowsIn) + winDataSizeOffset_ + rankId * OPT_RANK_OFFSET; + } + return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))->windowsIn) + + winDataSizeOffset_ + rankId * OPT_RANK_OFFSET; + } + + __aicore__ inline GM_ADDR GetWindStateAddrByRankId(uint8_t ctxIdx, const int32_t rankId) + { + uint32_t curRankId = ctxIdx == COMM_EP_IDX ? epRankId_ : tpRankId_; + if (curRankId == rankId) { + return (GM_ADDR)(winContext_[ctxIdx]->localWindowsExp) + dataState_ * WIN_STATE_OFFSET; + } + return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr)) + ->windowsExp) + + dataState_ * WIN_STATE_OFFSET; + } + + __aicore__ inline uint32_t MIN(uint32_t x, uint32_t y) + { + return (x < y) ? x : y; + } + TPipe *tpipe_{nullptr}; + GlobalTensor xGMTensor_; + GlobalTensor expertIdsGMTensor_; + GlobalTensor scalesGMTensor_; + GlobalTensor expandXOutGMTensor_; + GlobalTensor dynamicScalesOutGMTensor_; + GlobalTensor expertTokenNumsOutGMTensor_; + GlobalTensor windowInQuantTensor_; + GlobalTensor windowInstatusTensor_; + GlobalTensor windowInstatusFp32Tensor_; + GlobalTensor winTpGatherOutGMTensor_; + GlobalTensor fpWinTpGatherOutGMTensor_; + GlobalTensor winTpEpCntGMTensor_; + LocalTensor xTmpTensor_; + LocalTensor tpTmpTensor_; + LocalTensor xInTensor_; + LocalTensor xOutTensor_; + LocalTensor xOutFp32Tensor_; + LocalTensor expertCountTensor_; + LocalTensor expertIdsTensor_; + LocalTensor receivestatusTensor_; + LocalTensor rowMaxTensor_; + LocalTensor statusTensor_; + LocalTensor statusFp32Tensor_; + LocalTensor smoothScalesTensor_; + LocalTensor dynamicScalesTensor_; + TBuf<> dynamicScalesBuf_; + TBuf<> expertCountBuf_; + TBuf<> expertIdsBuf_; + TBuf<> statusBuf_; + TBuf<> gatherMaskOutBuf_; // gather mask output buf + TBuf<> getTotalBuf_; // compute totalCnt + TBuf<> scalarBuf_; + TBuf<> rowMaxBuf_; + TBuf<> receiveDataCastFloatBuf_; + TBuf<> smoothScalesBuf_; + TQueBind xQueue_; + TQue xInQueue_; + TQue xOutQueue_; + GM_ADDR expandXOutGM_; + GM_ADDR expandIdxOutGM_; + GM_ADDR expertTokenNumsOutGM_; + GM_ADDR sendCountsOutGM_; + GM_ADDR sendTpCountOutGM_; + GM_ADDR statusSpaceGm_; + GM_ADDR windowGM_; + GM_ADDR tpWindowGM_; + GM_ADDR tpStatusWindowGM_; + GM_ADDR tpLocalWindowGM_; + GM_ADDR tpLocalStatusWindowGM_; + GlobalTensor peerMemsAddrGm_; + + uint32_t axisBS_{0}; + uint32_t axisMaxBS_{0}; + uint32_t axisH_{0}; + uint32_t axisK_{0}; + uint32_t aivNum_{0}; + uint32_t sharedUsedAivNum_{0}; + uint32_t moeUsedAivNum_{0}; + uint32_t epWorldSize_{0}; + uint32_t tpWorldSize_{0}; + uint32_t epRankId_{0}; + uint32_t tpGatherRankId_{0}; + uint32_t tpRankId_{0}; + uint32_t aivId_{0}; // aiv id + uint32_t sharedExpertRankNum_{0}; + uint32_t moeExpertRankNum_{0}; + uint32_t moeExpertNumPerRank_{0}; + uint32_t moeExpertNum_{0}; + uint32_t totalExpertNum_{0}; + uint32_t bufferSizePerRank_{0}; + uint32_t recvWinBlockNum_{0}; + uint32_t hSize_{0}; + uint32_t hOutSize_{0}; + uint32_t hCommuSize_{0}; + uint32_t scaleParamPad_{0}; + uint32_t axisHCommu_{0}; + uint32_t startExpertId_; + uint32_t endExpertId_; + uint32_t sendExpertNum_; + uint32_t localCopyCoreNum_; + uint32_t totalCnt_; + 
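+ // dataState_ (below) is read from the status window in Init(), which also toggles the
+ // stored value for the next launch; it selects which half of the HCCL window this launch
+ // uses via winDataSizeOffset_ = dataState_ * epWorldSize_ * expertPerSizeOnWin_ *
+ // moeExpertNumPerRank_ and the dataState_ * WIN_STATE_OFFSET term in
+ // GetWindStateAddrByRankId(), which looks like double-buffering across consecutive launches.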
uint32_t lastCore_{0}; + uint32_t dataState_{0}; + uint32_t stateOffset_{0}; + uint64_t winDataSizeOffset_{0}; + uint64_t expertPerSizeOnWin_{0}; + uint64_t windyquantOffset_; + bool isShareExpertRank_ = false; + bool isQuant_ = false; + float sumTarget_; + uint64_t totalWinSize_{0}; + uint32_t gatherCount_{0}; + uint32_t expertTokenNumsType_{1}; + uint32_t preCnt_{0}; + __gm__ HcclOpResParam *winContext_[COMM_NUM]{nullptr, nullptr}; + + TBuf<> sendTableIdsBuf_; + LocalTensor tableLocalTensor_; + LocalTensor sendCountLocalTensor_; + uint32_t moeExpertRankNumAligned_; + uint32_t moeExpertRankNumInt16Aligned_; + uint32_t tableElemCount_; + bool enableAivOpt_{false}; +}; + +template +__aicore__ inline void CamMoeDistributeDispatch::Init( + GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, + GM_ADDR expertTokenNumsOut, GM_ADDR sendCountsOut, GM_ADDR tpSendCountsOut, GM_ADDR workspaceGM, TPipe *pipe, + const FusedDeepMoeTilingData *tilingData) +{ + tpipe_ = pipe; + aivId_ = GetBlockIdx(); + epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; + GlobalTensor selfDataStatusTensor; + GM_ADDR statusDataSpaceGm; + + winContext_[COMM_EP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + winContext_[COMM_TP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext<1>(); + + statusDataSpaceGm = (GM_ADDR)(winContext_[COMM_EP_IDX]->localWindowsExp); + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfDataStatusTensor[aivId_ * UB_ALIGN]); + __asm__ __volatile__(""); + dataState_ = selfDataStatusTensor(aivId_ * UB_ALIGN); + if (dataState_ == 0) { + selfDataStatusTensor(aivId_ * UB_ALIGN) = 1; + } else { + selfDataStatusTensor(aivId_ * UB_ALIGN) = 0; + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfDataStatusTensor[aivId_ * UB_ALIGN]); + __asm__ __volatile__(""); + pipe_barrier(PIPE_ALL); + axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs; + axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + epWorldSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; + axisMaxBS_ = axisBS_; + moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; + sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; + expertTokenNumsType_ = 0; + totalWinSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.totalWinSize; + moeExpertRankNum_ = epWorldSize_ - sharedExpertRankNum_; + moeExpertNumPerRank_ = moeExpertNum_ / moeExpertRankNum_; + expertPerSizeOnWin_ = axisMaxBS_ * axisH_ * sizeof(XType); + winDataSizeOffset_ = dataState_ * epWorldSize_ * expertPerSizeOnWin_ * moeExpertNumPerRank_; + tpRankId_ = 0; + windowGM_ = GetWindAddrByRankId(COMM_EP_IDX, epRankId_); + statusSpaceGm_ = GetWindStateAddrByRankId(COMM_EP_IDX, epRankId_); + tpGatherRankId_ = tpRankId_ == 0 ? 
1 : 0; + axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k; + aivNum_ = 48; + tpWorldSize_ = 1; + xGMTensor_.SetGlobalBuffer((__gm__ XType *)x); + expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)expertIds); + expandXOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)expandXOut); + dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)dynamicScalesOut); + expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)expertTokenNumsOut); + windowInQuantTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)windowGM_); + windowInstatusTensor_.SetGlobalBuffer((__gm__ int32_t *)(statusSpaceGm_)); + windowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)(statusSpaceGm_)); + if constexpr (IsNeedAllgater) { + tpLocalWindowGM_ = GetWindAddrByRankId(COMM_TP_IDX, tpRankId_); + tpLocalStatusWindowGM_ = GetWindStateAddrByRankId(COMM_TP_IDX, tpRankId_); + tpWindowGM_ = GetWindAddrByRankId(COMM_TP_IDX, tpGatherRankId_); + tpStatusWindowGM_ = GetWindStateAddrByRankId(COMM_TP_IDX, tpGatherRankId_); + winTpGatherOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)tpWindowGM_); + fpWinTpGatherOutGMTensor_.SetGlobalBuffer((__gm__ float *)tpWindowGM_); + winTpEpCntGMTensor_.SetGlobalBuffer((__gm__ int32_t *)(tpStatusWindowGM_ + TP_STATE_SIZE)); + } + expandXOutGM_ = expandXOut; + expandIdxOutGM_ = expandIdxOut; // no GlobalTensor + sendCountsOutGM_ = sendCountsOut; // no GlobalTensor + sendTpCountOutGM_ = tpSendCountsOut; + isQuant_ = StaticQuant | DynamicQuant; + hSize_ = axisH_ * sizeof(XType); + hOutSize_ = axisH_ * sizeof(ExpandXOutType); + scaleParamPad_ = (isQuant_ ? 128 : 0); + hCommuSize_ = hOutSize_ + scaleParamPad_; + axisHCommu_ = hCommuSize_ / sizeof(ExpandXOutType); + if (sharedExpertRankNum_ != 0) { + sharedUsedAivNum_ = aivNum_ / (axisK_ + 1); + if (sharedUsedAivNum_ == 0) { + sharedUsedAivNum_ = 1; + } + } + moeUsedAivNum_ = aivNum_ - sharedUsedAivNum_; + bufferSizePerRank_ = 32 * hSize_; + recvWinBlockNum_ = epWorldSize_ * moeExpertNumPerRank_; + isShareExpertRank_ = (epRankId_ < sharedExpertRankNum_) ? true : false; + windyquantOffset_ = epWorldSize_ * axisMaxBS_ * hOutSize_; + GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusSpaceGm_ + SELF_STATE_OFFSET)); + DataCacheCleanAndInvalid( + selfStatusTensor[aivId_ * UB_ALIGN]); + int32_t state = selfStatusTensor(aivId_ * UB_ALIGN); + stateOffset_ = (recvWinBlockNum_ > 512) ? 
(STATE_OFFSET / 2) : STATE_OFFSET; + tpipe_->InitBuffer(statusBuf_, recvWinBlockNum_ * UB_ALIGN); // expertNum * 32B + statusTensor_ = statusBuf_.Get(); + Duplicate(statusTensor_, 0, recvWinBlockNum_ * 8); // 8 = UB_ALIGN / sizeof(int32_t) + if (state == 0) { + sumTarget_ = (float)1.0; + selfStatusTensor(aivId_ * UB_ALIGN) = 0x3F800000; + uint64_t mask[2] = {0x101010101010101, 0}; + Duplicate(statusTensor_, 0x3F800000, mask, recvWinBlockNum_ / 8, 1, 8); + } else { + sumTarget_ = 0.0; + selfStatusTensor(aivId_ * UB_ALIGN) = 0; + } + DataCacheCleanAndInvalid( + selfStatusTensor[aivId_ * UB_ALIGN]); + tpipe_->InitBuffer(xQueue_, BUFFER_NUM, hCommuSize_); // 14k *2 + if (isQuant_) { + QuantInit(scales); + } + uint32_t expertIdsSize = axisBS_ * axisK_ * sizeof(int32_t); // 32 alignment + tpipe_->InitBuffer(expertIdsBuf_, expertIdsSize); // BS * K * 4 + expertIdsTensor_ = expertIdsBuf_.Get(); + tpipe_->InitBuffer(expertCountBuf_, expertIdsSize); // BS * K * 4 + expertCountTensor_ = expertCountBuf_.Get(); + + tpipe_->InitBuffer(gatherMaskOutBuf_, recvWinBlockNum_ * sizeof(float)); // worldsize * 4B + tpipe_->InitBuffer(getTotalBuf_, epWorldSize_ * moeExpertNumPerRank_ * sizeof(int32_t)); + tpipe_->InitBuffer(scalarBuf_, UB_ALIGN * 2); // 72B + + moeExpertRankNumAligned_ = Ceil(moeExpertNum_, TABLE_ELEM_COUNT_PER_BLOCK) * TABLE_ELEM_COUNT_PER_BLOCK; + if (axisBS_ <= LOOP_OPT_MAX_BS && moeExpertRankNumAligned_ <= LOOP_OPT_MAX_MOE_RANK && + axisK_ % TOPK_ELEM_COUNT_PER_BLOCK == 0) { + enableAivOpt_ = true; + moeExpertRankNumInt16Aligned_ = moeExpertRankNumAligned_ / 2; + tableElemCount_ = (axisBS_ + 1) * moeExpertRankNumAligned_; + + tpipe_->InitBuffer(sendTableIdsBuf_, tableElemCount_ * sizeof(countType)); + tableLocalTensor_ = sendTableIdsBuf_.Get(); + sendCountLocalTensor_ = tableLocalTensor_[axisBS_ * moeExpertRankNumAligned_]; + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::QuantInit(GM_ADDR scales) +{ + tpipe_->InitBuffer(xInQueue_, BUFFER_NUM, hSize_); // 14K *2 + tpipe_->InitBuffer(xOutQueue_, BUFFER_NUM, hCommuSize_); // 7K *2 + scalesGMTensor_.SetGlobalBuffer((__gm__ float *)scales); + uint32_t hFp32Size = axisH_ * sizeof(float); + if constexpr (DynamicQuant) { + tpipe_->InitBuffer(rowMaxBuf_, UB_ALIGN); // 32B + } + tpipe_->InitBuffer(receiveDataCastFloatBuf_, 1 * hFp32Size); // 28KB + tpipe_->InitBuffer(smoothScalesBuf_, axisH_ * sizeof(float)); // 28KB + smoothScalesTensor_ = smoothScalesBuf_.Get(); + tpipe_->InitBuffer(dynamicScalesBuf_, axisBS_ * sizeof(float)); // 32 * 4 + dynamicScalesTensor_ = dynamicScalesBuf_.Get(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::SendToSharedExpert() +{ + uint32_t sendTokenNum = axisBS_ / sharedUsedAivNum_; + uint32_t remainderTokenNum = axisBS_ % sharedUsedAivNum_; + uint32_t newAivId = aivId_ - moeUsedAivNum_; + uint32_t startTokenId = sendTokenNum * newAivId; + if (newAivId < remainderTokenNum) { + sendTokenNum += 1; + startTokenId += newAivId; + } else { + startTokenId += remainderTokenNum; + } + if (startTokenId >= axisBS_) { + return; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + for (uint32_t tokenShuffleIndex = 0; tokenShuffleIndex < sendTokenNum; ++tokenShuffleIndex) { + uint32_t tokenIndex = startTokenId + ((tokenShuffleIndex + epRankId_) % sendTokenNum); + uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_; + uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_; // dst + uint32_t preCnt = + (moeOnShareRank + epRankId_) * 
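+            // preCnt counts this rank's tokens already routed to lower-numbered shared-expert ranks,
+            // so (tokenIndex - preCnt) below is the slot index within this rank's region on the destination rank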
axisBS_ / sharedExpertRankNum_ - epRankId_ * axisBS_ / sharedExpertRankNum_; + GlobalTensor dstWinGMTensor; + dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType *)(GetWindAddrByRankId(COMM_EP_IDX, moeOnShareRank) + + expertPerSizeOnWin_ * epRankId_)); + if constexpr (DynamicQuant || StaticQuant) { + xInTensor_ = xInQueue_.AllocTensor(); + DataCopy(xInTensor_, xGMTensor_[tokenIndex * axisH_], axisH_); + xInQueue_.EnQue(xInTensor_); + xInTensor_ = xInQueue_.DeQue(); + xOutTensor_ = xOutQueue_.AllocTensor(); + QuantProcess(0); + xOutQueue_.EnQue(xOutTensor_); + + xOutTensor_ = xOutQueue_.DeQue(); + if (isShareExpertRank_) { + xOutFp32Tensor_ = xOutTensor_.template ReinterpretCast(); + DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + DataCopyPad(dynamicScalesOutGMTensor_[tokenIndex], xOutFp32Tensor_[axisH_ / sizeof(float)], + dataCopyParamsFloat); + if constexpr (IsNeedAllgater) { + DataCopy(winTpGatherOutGMTensor_[tokenIndex * axisHCommu_], xOutTensor_, axisHCommu_); + } + DataCopy(expandXOutGMTensor_[tokenIndex * axisH_], xOutTensor_, axisH_); + } else { + DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommu_], xOutTensor_, axisHCommu_); + } + xOutQueue_.FreeTensor(xOutTensor_); + } else { + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, xGMTensor_[tokenIndex * axisH_], axisH_); + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + if (isShareExpertRank_) { + if constexpr (IsNeedAllgater) { + DataCopy(winTpGatherOutGMTensor_[tokenIndex * axisHCommu_], xTmpTensor_, axisHCommu_); + } + DataCopy(expandXOutGMTensor_[tokenIndex * axisHCommu_], xTmpTensor_, axisHCommu_); + } else { + DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommu_], xTmpTensor_, axisHCommu_); + } + xQueue_.FreeTensor(xTmpTensor_); + } + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::SendToMoeExpert() +{ + uint32_t expertIdsCnt = axisBS_ * axisK_; + uint32_t sendTokenNum = expertIdsCnt / moeUsedAivNum_; + uint32_t remainderTokenNum = expertIdsCnt % moeUsedAivNum_; + uint32_t startTokenId = sendTokenNum * aivId_; + if (aivId_ < remainderTokenNum) { + sendTokenNum += 1; + startTokenId += aivId_; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + GlobalTensor dstWinGMTensor; + for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) { + uint32_t dstExpertId = expertIdsTensor_(tokenIndex); + uint32_t tempRankId = dstExpertId / moeExpertNumPerRank_ + sharedExpertRankNum_; + GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindAddrByRankId(COMM_EP_IDX, tempRankId) + + (expertPerSizeOnWin_ * + (epRankId_ * moeExpertNumPerRank_ + dstExpertId % moeExpertNumPerRank_)) + + hCommuSize_ * expertCountTensor_(tokenIndex)); + dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType *)rankGM); + if constexpr (DynamicQuant || StaticQuant) { + xInTensor_ = xInQueue_.AllocTensor(); + DataCopy(xInTensor_, xGMTensor_[tokenIndex / axisK_ * axisH_], axisH_); + xInQueue_.EnQue(xInTensor_); + xInTensor_ = xInQueue_.DeQue(); + xOutTensor_ = xOutQueue_.AllocTensor(); + uint32_t expertIndex = sharedExpertRankNum_ != 0 ? 
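+            // smooth-scale row 0 is reserved for the shared expert, so MoE expert ids are shifted by one whenever shared-expert ranks exist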
(dstExpertId + 1) : dstExpertId; + QuantProcess(expertIndex); + xOutQueue_.EnQue(xOutTensor_); + + xOutTensor_ = xOutQueue_.DeQue(); + DataCopy(dstWinGMTensor, xOutTensor_, axisHCommu_); + xOutQueue_.FreeTensor(xOutTensor_); + } else { + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, xGMTensor_[tokenIndex / axisK_ * axisH_], axisH_); + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + DataCopy(dstWinGMTensor, xTmpTensor_, axisHCommu_); + xQueue_.FreeTensor(xTmpTensor_); + } + } + if (aivId_ == (moeUsedAivNum_ - 1) && (!enableAivOpt_)) { + GlobalTensor expandIdxGMTensor; + expandIdxGMTensor.SetGlobalBuffer((__gm__ int32_t *)expandIdxOutGM_); + DataCopyExtParams expertIdsCntParams = {1U, static_cast(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, 0U}; + DataCopyPad(expandIdxGMTensor, expertCountTensor_, expertIdsCntParams); + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::AlltoAllDispatch() +{ + uint32_t expertIdsCnt = axisBS_ * axisK_; + DataCopyExtParams expertIdsCntParams = {1U, static_cast(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, 0U}; + DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams); + AscendC::TQueSync expertCntLocalSync; + expertCntLocalSync.SetFlag(0); + expertCntLocalSync.WaitFlag(0); + if (enableAivOpt_) { + LocalTensor tableInt16LocalTensor_ = tableLocalTensor_.template ReinterpretCast(); + Duplicate(tableInt16LocalTensor_, (int16_t)0, tableElemCount_ / 2); + SyncFunc(); + for (int tokenIndex = 0; tokenIndex < expertIdsCnt; ++tokenIndex) { + int expertId = expertIdsTensor_(tokenIndex); + tableLocalTensor_((tokenIndex / axisK_ + 1) * moeExpertRankNumAligned_ + expertId) = 1; + } + pipe_barrier(PIPE_ALL); + + uint32_t sendTokenNum = expertIdsCnt / moeUsedAivNum_; + uint32_t remainderTokenNum = expertIdsCnt % moeUsedAivNum_; + uint32_t startTokenId = sendTokenNum * aivId_; + if (aivId_ < remainderTokenNum) { + sendTokenNum += 1; + startTokenId += aivId_; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + uint32_t startTokenRow = startTokenId / axisK_; + uint32_t endTokenRow = (endTokenId + axisK_ - 1) / axisK_; + + for (int row = 1; row <= axisBS_; ++row) { + Add(tableInt16LocalTensor_[row * moeExpertRankNumInt16Aligned_], + tableInt16LocalTensor_[row * moeExpertRankNumInt16Aligned_], + tableInt16LocalTensor_[(row - 1) * moeExpertRankNumInt16Aligned_], moeExpertRankNumInt16Aligned_); + pipe_barrier(PIPE_V); + } + + GlobalTensor expandIdxGMTensor; + if (aivId_ < moeUsedAivNum_) { + SyncFunc(); + for (int row = startTokenRow; row < endTokenRow; ++row) { + for (int expertIndex = 0; expertIndex < axisK_; ++expertIndex) { + int32_t expertId = expertIdsTensor_(row * axisK_ + expertIndex); + expertCountTensor_(row * axisK_ + expertIndex) = + (int32_t)tableLocalTensor_(row * moeExpertRankNumAligned_ + expertId); + } + SyncFunc(); + expandIdxGMTensor.SetGlobalBuffer( + (__gm__ int32_t *)(expandIdxOutGM_ + row * axisK_ * sizeof(uint32_t))); + DataCopy(expandIdxGMTensor, expertCountTensor_[row * axisK_], axisK_); + } + } + + uint32_t preTotalExpertNum = sharedExpertRankNum_ + moeExpertNum_; + uint32_t preSendExpertNum = preTotalExpertNum / aivNum_; + uint32_t preRemainderRankNum = preTotalExpertNum % aivNum_; + uint32_t preStartExpertId = preSendExpertNum * aivId_; + if (aivId_ < preRemainderRankNum) { + preSendExpertNum += 1; + preStartExpertId += aivId_; + } else { + preStartExpertId += 
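+            // remainder split of the status entries across cores: e.g. 36 entries (4 shared + 32 MoE) over aivNum_ = 48
+            // gives 0 per core with remainder 36, so cores 0..35 are assigned one entry each (shared-expert slots among
+            // them are then skipped by the clamp below) and cores 36..47 are assigned none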
preRemainderRankNum; + } + uint32_t preEndExpertId = preStartExpertId + preSendExpertNum; + preStartExpertId = preStartExpertId >= sharedExpertRankNum_ ? preStartExpertId : sharedExpertRankNum_; + + SyncFunc(); + for (int32_t tmpExpertId = preStartExpertId; tmpExpertId < preEndExpertId; ++tmpExpertId) { + statusTensor_(tmpExpertId * INT32_NUM_PER_BLOCK + 1) = + (int32_t)sendCountLocalTensor_(tmpExpertId - sharedExpertRankNum_); + } + } else { + for (uint32_t tokenIndex = 0; tokenIndex < expertIdsCnt; ++tokenIndex) { + int32_t expertId = expertIdsTensor_(tokenIndex) + sharedExpertRankNum_; + expertCountTensor_(tokenIndex) = statusTensor_(expertId * INT32_NUM_PER_BLOCK + 1); + statusTensor_(expertId * INT32_NUM_PER_BLOCK + 1)++; + } + } + if (!isShareExpertRank_) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum_; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId_) * axisBS_ / sharedExpertRankNum_ - + (curSatatusExpId + epRankId_) * axisBS_ / sharedExpertRankNum_; + statusTensor_((curSatatusExpId)*INT32_NUM_PER_BLOCK + 1) = curExpertCnt; + } + } + + if ((sharedExpertRankNum_ != 0) && (aivId_ >= moeUsedAivNum_)) { + SendToSharedExpert(); + return; + } + SendToMoeExpert(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::SetStatus() +{ + pipe_barrier(PIPE_ALL); + SyncAll(); + totalExpertNum_ = sharedExpertRankNum_ + moeExpertNum_; + sendExpertNum_ = totalExpertNum_ / aivNum_; + uint32_t remainderRankNum = totalExpertNum_ % aivNum_; + startExpertId_ = sendExpertNum_ * aivId_; + if (aivId_ < remainderRankNum) { + sendExpertNum_ += 1; + startExpertId_ += aivId_; + } else { + startExpertId_ += remainderRankNum; + } + endExpertId_ = startExpertId_ + sendExpertNum_; + if (startExpertId_ >= totalExpertNum_) { + return; + } + GlobalTensor rankGMTensor; + uint32_t offset = stateOffset_ * epRankId_; + for (uint32_t rankIndex = startExpertId_; rankIndex < endExpertId_; ++rankIndex) { + uint32_t dstRankId = rankIndex; + if (moeExpertNumPerRank_ > 1 && (rankIndex >= sharedExpertRankNum_)) { + dstRankId = ((rankIndex - sharedExpertRankNum_) / moeExpertNumPerRank_ + sharedExpertRankNum_); + offset = + (epRankId_ + (rankIndex - sharedExpertRankNum_) % moeExpertNumPerRank_ * epWorldSize_) * stateOffset_; + } + GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_EP_IDX, dstRankId) + offset); + rankGMTensor.SetGlobalBuffer((__gm__ int32_t *)rankGM); + DataCopy(rankGMTensor, statusTensor_[rankIndex * 8], 8UL); + } + SyncFunc(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::QuantProcess(uint32_t expertIndex) +{ + float dynamicScale = 0.0; + LocalTensor floatLocalTemp; + floatLocalTemp = receiveDataCastFloatBuf_.Get(); + Cast(floatLocalTemp, xInTensor_, RoundMode::CAST_NONE, axisH_); + xInQueue_.FreeTensor(xInTensor_); + pipe_barrier(PIPE_V); + if constexpr (IsSmoothScaleExist) { + if constexpr (DynamicQuant) { + SyncFunc(); + } + DataCopy(smoothScalesTensor_, scalesGMTensor_[expertIndex * axisH_], axisH_); + SyncFunc(); + Mul(floatLocalTemp, floatLocalTemp, smoothScalesTensor_, axisH_); + pipe_barrier(PIPE_V); + } + if constexpr (DynamicQuant) { + LocalTensor floatLocalAbsTemp = smoothScalesBuf_.Get(); + rowMaxTensor_ = rowMaxBuf_.Get(); + Abs(floatLocalAbsTemp, floatLocalTemp, axisH_); + pipe_barrier(PIPE_V); + ReduceMax(rowMaxTensor_, floatLocalAbsTemp, floatLocalAbsTemp, axisH_, false); + SyncFunc(); + dynamicScale = float(127.0) / rowMaxTensor_.GetValue(0); + SyncFunc(); + Muls(floatLocalTemp, 
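+        // dynamic int8 quantization: scale the row by 127 / max|x| before the cast chain below; the reciprocal
+        // scale is stored just past the int8 payload (float offset axisH_ / sizeof(float)) so the receiver can dequantize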
floatLocalTemp, dynamicScale, axisH_); + pipe_barrier(PIPE_V); + } + LocalTensor halfLocalTemp = floatLocalTemp.ReinterpretCast(); + LocalTensor int32LocalTemp = floatLocalTemp.ReinterpretCast(); + Cast(int32LocalTemp, floatLocalTemp, RoundMode::CAST_RINT, axisH_); + pipe_barrier(PIPE_V); + SetDeqScale((half)1.000000e+00f); + PipeBarrier(); + Cast(halfLocalTemp, int32LocalTemp, RoundMode::CAST_ROUND, axisH_); + pipe_barrier(PIPE_V); + Cast(xOutTensor_, halfLocalTemp, RoundMode::CAST_TRUNC, axisH_); + floatLocalTemp = xOutTensor_.template ReinterpretCast(); + floatLocalTemp.SetValue(axisH_ / sizeof(float), float(1.0) / dynamicScale); // int8->float32 +} + +template +__aicore__ inline void CamMoeDistributeDispatch::LocalSharedExpertCopyWindow( + uint32_t rankIndex, uint32_t tokenOffset, uint32_t currendTokenIndex, uint32_t &dynamicScalesLocalIdx) +{ + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, + windowInQuantTensor_[rankIndex * (expertPerSizeOnWin_ / sizeof(ExpandXOutType)) + + currendTokenIndex * axisHCommu_], + axisHCommu_); + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + if constexpr (DynamicQuant || StaticQuant) { + pipe_barrier(PIPE_ALL); + xOutFp32Tensor_ = xTmpTensor_.template ReinterpretCast(); + dynamicScalesTensor_.SetValue(dynamicScalesLocalIdx++, xOutFp32Tensor_.GetValue(axisH_ / sizeof(float))); + pipe_barrier(PIPE_ALL); + } + if constexpr (IsNeedAllgater) { + DataCopy(winTpGatherOutGMTensor_[tokenOffset * axisH_], xTmpTensor_, axisH_); + } + DataCopy(expandXOutGMTensor_[tokenOffset * axisH_], xTmpTensor_, axisH_); + xQueue_.FreeTensor(xTmpTensor_); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::WaitDispatch() +{ + uint32_t rscvStatusNum = isShareExpertRank_ ? epWorldSize_ : recvWinBlockNum_; + uint32_t recStatusNumPerCore = rscvStatusNum / aivNum_; + uint32_t remainderRankNum = rscvStatusNum % aivNum_; + uint32_t startStatusIndex = recStatusNumPerCore * aivId_; + if (aivId_ < remainderRankNum) { + recStatusNumPerCore += 1; + startStatusIndex += aivId_; + } else { + startStatusIndex += remainderRankNum; + } + if (startStatusIndex >= rscvStatusNum) { + SyncAll(); + return; + } + LocalTensor gatherMaskOutTensor = gatherMaskOutBuf_.Get(); + LocalTensor gatherTmpTensor = scalarBuf_.GetWithOffset(UB_ALIGN / sizeof(uint32_t), 0); + gatherTmpTensor.SetValue(0, 1); + LocalTensor statusSumOutTensor = scalarBuf_.GetWithOffset(UB_ALIGN / sizeof(float), UB_ALIGN); + statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + uint32_t mask = 1; // gatherMask + sum + uint64_t rsvdCnt = 0; + SumParams sumParams{1, recStatusNumPerCore, recStatusNumPerCore}; + float sumOfFlag = static_cast(-1.0); + float minTarget = (sumTarget_ * recStatusNumPerCore) - (float)0.5; + float maxTarget = (sumTarget_ * recStatusNumPerCore) + (float)0.5; + DataCopyParams intriParams{static_cast(recStatusNumPerCore), 1, + static_cast((recvWinBlockNum_ > 512) ? 
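+    // gap (in 32B blocks) between consecutive rank-status entries; the smaller spacing mirrors the halved
+    // stateOffset_ chosen in Init when recvWinBlockNum_ > 512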
7 : 15), 0}; // srcStride is 15 blocks + SyncFunc(); + while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) { + DataCopy(statusFp32Tensor_, windowInstatusFp32Tensor_[startStatusIndex * stateOffset_ / sizeof(float)], + intriParams); + SyncFunc(); + GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, mask, + {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt); + pipe_barrier(PIPE_V); + Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams); + SyncFunc(); + sumOfFlag = statusSumOutTensor.GetValue(0); + } + SyncAll(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::GetCumSum(LocalTensor &inLocal, + LocalTensor &outLocal, + int32_t totalCount) +{ + statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + DataCopyParams intriParams{static_cast(recvWinBlockNum_), 1, + static_cast((recvWinBlockNum_ > 512) ? 7 : 15), 0}; // srcStride is 15 blocks + DataCopy(statusTensor_, windowInstatusTensor_, intriParams); + SyncFunc(); + if (isShareExpertRank_) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum_; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId_) * axisBS_ / sharedExpertRankNum_ - + (curSatatusExpId + epRankId_) * axisBS_ / sharedExpertRankNum_; + statusTensor_((curSatatusExpId)*INT32_NUM_PER_BLOCK + 1) = curExpertCnt; + } + } + outLocal = gatherMaskOutBuf_.Get(); + LocalTensor getTotalLocal = getTotalBuf_.Get(); + // gather mask + TBuf<> gatherTmpBuf; + TBuf<> workLocalBuf; + tpipe_->InitBuffer(gatherTmpBuf, sizeof(uint32_t) * recvWinBlockNum_ / 4); + LocalTensor gatherTmpTensor = gatherTmpBuf.Get(); + Duplicate(gatherTmpTensor, (uint32_t)33686018, recvWinBlockNum_ / 4); // 0000 0010 0000 0010 0000 0010 0000 0010 + PipeBarrier(); + uint32_t mask = recvWinBlockNum_ * 8; // 512 / 32 + uint64_t rsvdCnt = 0; + GatherMask(outLocal, inLocal, gatherTmpTensor, true, mask, {1, 1, 0, 0}, rsvdCnt); + + int typeSize = sizeof(int32_t); + int32_t elementsPerBlock = 32 / typeSize; + int32_t elementsPerRepeat = 256 / typeSize; + int32_t firstMaxRepeat = epWorldSize_; + int32_t iter1OutputCount = firstMaxRepeat; + int32_t iter1AlignEnd = ((iter1OutputCount + elementsPerBlock - 1) / elementsPerBlock) * elementsPerBlock; + int32_t finalWorkLocalNeedSize = iter1AlignEnd; + tpipe_->InitBuffer(workLocalBuf, finalWorkLocalNeedSize * sizeof(int32_t)); + LocalTensor workLocalTensor = workLocalBuf.Get(); + LocalTensor tmpFp32 = outLocal.ReinterpretCast(); + PipeBarrier(); + ReduceSum(getTotalLocal, tmpFp32, workLocalTensor, epWorldSize_); + totalCnt_ = getTotalLocal.ReinterpretCast().GetValue(0); + PipeBarrier(); + ReduceSum(tmpFp32, tmpFp32, workLocalTensor, totalCount); + PipeBarrier(); +} + +template +__aicore__ inline void +CamMoeDistributeDispatch::CreateZeroTensor(LocalTensor &outLocal) +{ + TBuf<> outBuf; + tpipe_->InitBuffer(outBuf, UB_ALIGN); + outLocal = outBuf.Get(); + for (uint32_t i = 0; i < 2; i++) { + outLocal.SetValue(i, 0); + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::LocalWindowCopy() +{ + uint32_t totalMoeExpert = 0; + LocalTensor outCountLocal; + if (isShareExpertRank_) { + totalMoeExpert = epWorldSize_; + } else { + totalMoeExpert = epWorldSize_ * moeExpertNumPerRank_; + } + sendExpertNum_ = totalMoeExpert / aivNum_; + uint32_t remainderRankNum = totalMoeExpert % aivNum_; + startExpertId_ = sendExpertNum_ * aivId_; + if (aivId_ < remainderRankNum) { + sendExpertNum_ += 1; + startExpertId_ += aivId_; + } else { + startExpertId_ += remainderRankNum; + } + endExpertId_ = 
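+    // [startExpertId_, endExpertId_) is this core's slice of receive slots whose tokens are copied out of the local window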
startExpertId_ + sendExpertNum_; + if (startExpertId_ >= totalMoeExpert) { + return; + } + GetCumSum(statusTensor_, outCountLocal, startExpertId_ + 1); + uint32_t index = 0; + uint32_t beginIdx = 0; + DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + for (uint32_t index = startExpertId_; index < endExpertId_; index++) { + uint32_t i = index - startExpertId_; + if (i > 0) { + outCountLocal.SetValue(i, outCountLocal.GetValue(i - 1) + outCountLocal.GetValue(index)); + } + uint32_t count = statusTensor_.GetValue(index * INT32_NUM_PER_BLOCK + 1); + beginIdx = outCountLocal.GetValue(i) - count; + if constexpr (IsNeedAllgater) { + gatherCount_ += count; + } + if (i == 0) { + preCnt_ = beginIdx; + } + if (isShareExpertRank_) { + if (index < sharedExpertRankNum_) { + beginIdx += count; + continue; + } + } + uint32_t winOffset = index; + if (!isShareExpertRank_) { + if (moeExpertNumPerRank_ > 1) { + winOffset = index % epWorldSize_ * moeExpertNumPerRank_ + index / epWorldSize_; + } + } + GM_ADDR wAddr = (__gm__ uint8_t *)(windowGM_) + winOffset * expertPerSizeOnWin_; + GlobalTensor tokGlobal; + GlobalTensor expandXOutGlobal; + for (uint32_t j = 0; j < count; j++) { + tokGlobal.SetGlobalBuffer((__gm__ ExpandXOutType *)(wAddr + j * hCommuSize_)); + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, tokGlobal, axisHCommu_); + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + if constexpr (DynamicQuant || StaticQuant) { + pipe_barrier(PIPE_ALL); + xOutFp32Tensor_ = xTmpTensor_.template ReinterpretCast(); + DataCopyPad(dynamicScalesOutGMTensor_[beginIdx + j], xOutFp32Tensor_[axisH_ / sizeof(float)], + dataCopyParamsFloat); + pipe_barrier(PIPE_ALL); + } + if constexpr (IsNeedAllgater) { + DataCopy(winTpGatherOutGMTensor_[(beginIdx + j) * axisHCommu_], xTmpTensor_, axisHCommu_); + } + expandXOutGlobal.SetGlobalBuffer((__gm__ ExpandXOutType *)(expandXOutGM_) + (beginIdx + j) * axisH_, + axisH_); + DataCopy(expandXOutGlobal, xTmpTensor_, axisH_); + xQueue_.FreeTensor(xTmpTensor_); + } + beginIdx += count; + } + if constexpr (!IsNeedAllgater) { + totalCnt_ = beginIdx; + } + lastCore_ = MIN(totalMoeExpert, aivNum_) - 1; + if constexpr (IsNeedAllgater) { + DataCopyExtParams dataCopyOutParams = {1U, static_cast(sendExpertNum_ * sizeof(int32_t)), 0U, 0U, 0U}; + DataCopyPad(winTpEpCntGMTensor_[startExpertId_], outCountLocal, dataCopyOutParams); + } + DataCopyExtParams dataCopyOutParams = {1U, static_cast(sendExpertNum_ * sizeof(int32_t)), 0U, 0U, 0U}; + GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCountsOutGM_)); + DataCopyPad(sendCountsGlobal[startExpertId_], outCountLocal, dataCopyOutParams); + PipeBarrier(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::AllGatherSetStatusAndWait() +{ + pipe_barrier(PIPE_ALL); + if (startExpertId_ >= totalExpertNum_) { + return; + } + GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_TP_IDX, tpGatherRankId_) + stateOffset_ * aivId_); + GlobalTensor tpwindowInstatusFp32Tensor_; + tpwindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)(rankGM)); + statusTensor_(aivId_ * INT32_NUM_PER_BLOCK + 1) = gatherCount_; + statusTensor_(aivId_ * INT32_NUM_PER_BLOCK + 2) = preCnt_; + LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + statusFp32Tensor_(aivId_ * 8) = sumTarget_; + SyncFunc(); + DataCopy(tpwindowInstatusFp32Tensor_, statusFp32Tensor_[aivId_ * 8], UB_ALIGN); + SyncFunc(); + float sumOfFlag = static_cast(-1.0); + rankGM 
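+    // switch to this rank's own TP status window and spin until the peer has written the expected flag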
= (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_TP_IDX, tpRankId_) + stateOffset_ * aivId_); + tpwindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)(rankGM)); + while (sumOfFlag != sumTarget_) { + DataCopy(statusFp32Tensor_, tpwindowInstatusFp32Tensor_, UB_ALIGN); + SyncFunc(); + sumOfFlag = statusFp32Tensor_.GetValue(0); + SyncFunc(); + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::AllgatherProcessOut() +{ + if (startExpertId_ >= totalExpertNum_) { + return; + } + GlobalTensor tpwindowInstatusFp32Tensor_; + GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_TP_IDX, tpRankId_) + stateOffset_ * aivId_); + tpwindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)rankGM); + LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + DataCopy(statusFp32Tensor_, tpwindowInstatusFp32Tensor_, UB_ALIGN); + SyncFunc(); + uint32_t coreGatherCount = statusFp32Tensor_.ReinterpretCast().GetValue(1); + uint32_t preCount = statusFp32Tensor_.ReinterpretCast().GetValue(2); + gatherCount_ = coreGatherCount; + preCnt_ = preCount; + GlobalTensor sendCountsGlobal; + GlobalTensor tpGlobal; + + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCountsOutGM_)); + tpGlobal.SetGlobalBuffer((__gm__ int32_t *)(tpLocalStatusWindowGM_ + TP_STATE_SIZE)); + DataCopyExtParams dataCopyParams = {1U, static_cast(sendExpertNum_ * sizeof(int32_t)), 0U, 0U, 0U}; + DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + tpTmpTensor_ = xQueue_.AllocTensor(); + DataCopyPad(tpTmpTensor_, tpGlobal[startExpertId_], dataCopyParams, copyPadParams); + xQueue_.EnQue(tpTmpTensor_); + tpTmpTensor_ = xQueue_.DeQue(); + DataCopyPad(sendCountsGlobal[epWorldSize_ + startExpertId_], tpTmpTensor_, dataCopyParams); + xQueue_.FreeTensor(tpTmpTensor_); + if (coreGatherCount == 0) { + return; + } + + GlobalTensor tokGlobal; + GlobalTensor expandXOutGlobal; + DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + for (uint32_t i = 0; i < coreGatherCount; i++) { + tokGlobal.SetGlobalBuffer((__gm__ ExpandXOutType *)(tpLocalWindowGM_ + (preCount + i) * hCommuSize_)); + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, tokGlobal, axisHCommu_); + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + expandXOutGlobal.SetGlobalBuffer( + (__gm__ ExpandXOutType *)(expandXOutGM_ + (preCount + totalCnt_ + i) * hOutSize_)); + DataCopy(expandXOutGlobal, xTmpTensor_, axisH_); + if constexpr (StaticQuant || DynamicQuant) { + xOutFp32Tensor_ = xTmpTensor_.template ReinterpretCast(); + DataCopyPad(dynamicScalesOutGMTensor_[preCount + totalCnt_ + i], xOutFp32Tensor_[axisH_ / sizeof(float)], + dataCopyParamsFloat); + } + xQueue_.FreeTensor(xTmpTensor_); + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::UpdataMultiMoeTokenNumsOut() +{ + uint32_t tokenSums = 0; + GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCountsOutGM_)); + for (uint32_t localMoeIndex = 0; localMoeIndex < moeExpertNumPerRank_; ++localMoeIndex) { + if (localMoeIndex == 0) { + DataCacheCleanAndInvalid( + sendCountsGlobal[epWorldSize_ - 1]); + uint32_t firstMoeCnt = sendCountsGlobal.GetValue(epWorldSize_ - 1); + tokenSums = firstMoeCnt + gatherCount_; + expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenSums); + DataCacheCleanAndInvalid( + expertTokenNumsOutGMTensor_[localMoeIndex]); + } else { + uint32_t preIndex = epWorldSize_ * (localMoeIndex - 1) + epWorldSize_ - 1; + uint32_t curIndex = epWorldSize_ * 
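+            // send counts are cumulative and grouped per local expert in chunks of epWorldSize_, so the difference
+            // between the last elements of consecutive chunks is that local expert's token count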
localMoeIndex + epWorldSize_ - 1; + DataCacheCleanAndInvalid( + sendCountsGlobal[preIndex]); + DataCacheCleanAndInvalid( + sendCountsGlobal[curIndex]); + uint32_t preMoeIndexCnt = sendCountsGlobal.GetValue(preIndex); + uint32_t curMoeIndexCnt = sendCountsGlobal.GetValue(curIndex); + tokenSums = + ((expertTokenNumsType_ == 0) ? tokenSums : 0) + (curMoeIndexCnt - preMoeIndexCnt) + gatherCount_; + expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenSums); + DataCacheCleanAndInvalid( + expertTokenNumsOutGMTensor_[localMoeIndex]); + } + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::UpdataTokenNumsOut() +{ + if (!isShareExpertRank_ && moeExpertNumPerRank_ > 1) { + SyncAll(); + if (aivId_ != lastCore_) return; + SyncFunc(); + UpdataMultiMoeTokenNumsOut(); + } else { + if (aivId_ != lastCore_) return; + uint32_t tokenNum = 0; + + tokenNum = totalCnt_; + if constexpr (IsNeedAllgater) { + tokenNum += preCnt_; + tokenNum += gatherCount_; + } + expertTokenNumsOutGMTensor_.SetValue(0, tokenNum); + DataCacheCleanAndInvalid( + expertTokenNumsOutGMTensor_); + } + + if constexpr (IsNeedAllgater) { + GlobalTensor sendTpCountsGlobal; + sendTpCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendTpCountOutGM_)); + sendTpCountsGlobal.SetValue(tpRankId_, totalCnt_); + sendTpCountsGlobal.SetValue(tpGatherRankId_, gatherCount_ + preCnt_); + DataCacheCleanAndInvalid(sendTpCountsGlobal); + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::Process() +{ + if ASCEND_IS_AIV { + AlltoAllDispatch(); + SetStatus(); + WaitDispatch(); + LocalWindowCopy(); + if constexpr (IsNeedAllgater) { + AllGatherSetStatusAndWait(); + AllgatherProcessOut(); + } + UpdataTokenNumsOut(); + } +} + +} // namespace MoeDistributeDispatchImpl +#endif // CAM_MOE_DISTRIBUTE_DISPATCH_H diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/arch/arch.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/arch/arch.hpp new file mode 100644 index 00000000..65363c75 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/arch/arch.hpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_ARCH_ARCH_HPP +#define CATLASS_ARCH_ARCH_HPP + +#include "catlass/catlass.hpp" + +namespace Catlass::Arch { + +struct AtlasA2 { + static constexpr uint32_t BIAS_SIZE = 1024; + static constexpr uint32_t FIXBUF_SIZE = 7 * 1024; + static constexpr uint32_t UB_SIZE = 192 * 1024; + static constexpr uint32_t L1_SIZE = 512 * 1024; + static constexpr uint32_t L0A_SIZE = 64 * 1024; + static constexpr uint32_t L0B_SIZE = 64 * 1024; + static constexpr uint32_t L0C_SIZE = 128 * 1024; +}; + +template +using PositionType = std::integral_constant; + +using PositionGM = PositionType; +using PositionL1 = PositionType; +using PositionL0A = PositionType; +using PositionL0B = PositionType; +using PositionL0C = PositionType; +using PositionUB = PositionType; + +} // namespace Catlass::Arch + +#endif // CATLASS_ARCH_ARCH_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/arch/cross_core_sync.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/arch/cross_core_sync.hpp new file mode 100644 index 00000000..7d8363cc --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/arch/cross_core_sync.hpp @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_ARCH_CROSS_CORE_SYNC_HPP +#define CATLASS_ARCH_CROSS_CORE_SYNC_HPP + +#include "catlass/catlass.hpp" + +namespace Catlass::Arch { + +constexpr uint32_t MAX_REVERSE_DEPTH = 16; + +using FlagID = uint16_t; +constexpr FlagID AIV_INTER_BLOCK_BARRIER = 8; +constexpr FlagID AIC_INTER_BLOCK_BARRIER = 9; +constexpr FlagID AIV_INTER_SUBBLOCK_BARRIER = 10; +constexpr FlagID FFTS_MAX_FLAG = 7; + +struct CrossCoreFlag { + CATLASS_DEVICE + CrossCoreFlag() : id(0) {} + + CATLASS_DEVICE + CrossCoreFlag(FlagID id) : id(id) {} + + FlagID id; +}; + +template +struct CrossCoreFlagWithReverse { + CATLASS_DEVICE + CrossCoreFlagWithReverse() : id(0), reverseId(0) {} + + CATLASS_DEVICE + CrossCoreFlagWithReverse(FlagID id, FlagID reverseId) : id(id), reverseId(reverseId) {} + + FlagID id; + FlagID reverseId; + uint32_t count{0}; +}; + +template +struct BarrierFlag { + static_assert(MODE != MODE, "Unsupported cross core barrier flag, can not find the specialization."); +}; + +template <> +struct BarrierFlag<0x0, AscendC::AIV> { + static constexpr FlagID ID = AIV_INTER_BLOCK_BARRIER; +}; + +template <> +struct BarrierFlag<0x0, AscendC::AIC> { + static constexpr FlagID ID = AIC_INTER_BLOCK_BARRIER; +}; + +template <> +struct BarrierFlag<0x1, AscendC::AIV> { + static constexpr FlagID ID = AIV_INTER_SUBBLOCK_BARRIER; +}; + +template +CATLASS_DEVICE void CrossCoreBarrier() +{ + constexpr FlagID flagId = BarrierFlag::ID; + AscendC::CrossCoreSetFlag(flagId); + AscendC::CrossCoreWaitFlag(flagId); +} + +template +CATLASS_DEVICE void CrossCoreSetFlag(CrossCoreFlag &flag) +{ + AscendC::CrossCoreSetFlag(flag.id); +} + +CATLASS_DEVICE +void CrossCoreWaitFlag(CrossCoreFlag &flag) +{ + AscendC::CrossCoreWaitFlag(flag.id); +} + +template +CATLASS_DEVICE void CrossCoreSetFlagWithReverse(CrossCoreFlagWithReverse &flag) +{ + AscendC::CrossCoreSetFlag(flag.id); + if (++flag.count >= REVERSE_DEPTH) { + AscendC::CrossCoreWaitFlag(flag.reverseId); + flag.count = 0; + } +} + +template +CATLASS_DEVICE void CrossCoreWaitFlagWithReverse(CrossCoreFlagWithReverse &flag) +{ + AscendC::CrossCoreWaitFlag(flag.id); + if (++flag.count >= REVERSE_DEPTH) { + AscendC::CrossCoreSetFlag(flag.reverseId); + flag.count = 0; + } +} + +} // namespace Catlass::Arch + +#endif // CATLASS_ARCH_CROSS_CORE_SYNC_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/arch/local_tensor_buffer.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/arch/local_tensor_buffer.hpp new file mode 100644 index 00000000..85029424 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/arch/local_tensor_buffer.hpp @@ -0,0 +1,229 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef INCLUDE_CATLASS_ARCH_MEMORY_H +#define INCLUDE_CATLASS_ARCH_MEMORY_H + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" + +namespace Catlass::Arch { + +struct LocalTensorBufferBase { +public: + template + CATLASS_DEVICE AscendC::LocalTensor GetBufferByByte(const uint32_t offset) const + { + return tensor[offset].template ReinterpretCast(); + } + +protected: + CATLASS_DEVICE + LocalTensorBufferBase() = default; + + AscendC::LocalTensor tensor; +}; + +template +struct LocalTensorBuffer { + static_assert(DEPENDENT_FALSE, "Unsupported local tensor buffer, can not find the specialization."); +}; + +/// Partial specialization for TPosition::A1 +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::A1; + + CATLASS_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufA1; + GetTPipePtr()->InitBuffer(tbufA1, ArchTag::L1_SIZE); + tensor = tbufA1.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for TPosition::A2 +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::A2; + + CATLASS_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufA2; + GetTPipePtr()->InitBuffer(tbufA2, ArchTag::L0A_SIZE); + tensor = tbufA2.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for TPosition::B1 +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::B1; + + CATLASS_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufB1; + GetTPipePtr()->InitBuffer(tbufB1, ArchTag::L1_SIZE); + tensor = tbufB1.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for AtlasA2, TPosition::B2 +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::B2; + + CATLASS_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufB2; + GetTPipePtr()->InitBuffer(tbufB2, ArchTag::L0B_SIZE); + tensor = tbufB2.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for AtlasA2, TPosition::C1 +template <> +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + using ArchTag = Arch::AtlasA2; + static constexpr AscendC::TPosition Position = AscendC::TPosition::C1; + + CATLASS_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufC1; + GetTPipePtr()->InitBuffer(tbufC1, ArchTag::L1_SIZE); + tensor = tbufC1.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for AtlasA2, TPosition::C2 +template <> +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + using ArchTag = Arch::AtlasA2; + static constexpr AscendC::TPosition Position = AscendC::TPosition::C2; + + CATLASS_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufC2; + GetTPipePtr()->InitBuffer(tbufC2, ArchTag::BIAS_SIZE); + tensor = tbufC2.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for TPosition::CO1 +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::CO1; + + CATLASS_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufCO1; + GetTPipePtr()->InitBuffer(tbufCO1, ArchTag::L0C_SIZE); + tensor = tbufCO1.Get(); + 
} +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for AtlasA2, TPosition::C2PIPE2GM +template <> +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + using ArchTag = Arch::AtlasA2; + static constexpr AscendC::TPosition Position = AscendC::TPosition::C2PIPE2GM; + + CATLASS_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufC2PIPE2GM; + GetTPipePtr()->InitBuffer(tbufC2PIPE2GM, ArchTag::FIXBUF_SIZE); + tensor = tbufC2PIPE2GM.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for TPosition::VECIN +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::VECIN; + + CATLASS_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufVECIN; + GetTPipePtr()->InitBuffer(tbufVECIN, ArchTag::UB_SIZE); + tensor = tbufVECIN.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for TPosition::VECOUT +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::VECOUT; + + CATLASS_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufVECOUT; + GetTPipePtr()->InitBuffer(tbufVECOUT, ArchTag::UB_SIZE); + tensor = tbufVECOUT.Get(); + } +}; + +/////////////////////////////////////////////////////////// + +/// Partial specialization for TPosition::VECCALC +template +struct LocalTensorBuffer : LocalTensorBufferBase { +public: + static constexpr AscendC::TPosition Position = AscendC::TPosition::VECCALC; + + CATLASS_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufVECCALC; + GetTPipePtr()->InitBuffer(tbufVECCALC, ArchTag::UB_SIZE); + tensor = tbufVECCALC.Get(); + } +}; + +} // namespace Catlass::Arch + +#endif // INCLUDE_CATLASS_ARCH_MEMORY_H diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/arch/resource.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/arch/resource.hpp new file mode 100644 index 00000000..ff2e03da --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/arch/resource.hpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef INCLUDE_CATLASS_ARCH_RESOURCE_HPP +#define INCLUDE_CATLASS_ARCH_RESOURCE_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/local_tensor_buffer.hpp" + +namespace Catlass::Arch { + +template +struct Resource { +public: + AscendC::TPipe pipe; + + LocalTensorBuffer l1Buf; + LocalTensorBuffer l0ABuf; + LocalTensorBuffer l0BBuf; + LocalTensorBuffer btBuf; + LocalTensorBuffer l0CBuf; + LocalTensorBuffer ubBuf; + + CATLASS_DEVICE + Resource() + { + // The initialization of AscendC::Tpipe will insert some synchronization interfaces, + // which may conflict with the usage by users. Therefore, the "destroy" interface is used for releasing. 
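+        // Note: the LocalTensorBuffer members above are constructed (and call InitBuffer) before this constructor
+        // body runs, so their buffer regions are already reserved by the time the pipe is destroyed here.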
+ pipe.Destroy(); + } +}; + +} // namespace Catlass::Arch + +#endif // INCLUDE_CATLASS_ARCH_RESOURCE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/catlass.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/catlass.hpp new file mode 100644 index 00000000..014257f3 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/catlass.hpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_CATLASS_HPP +#define CATLASS_CATLASS_HPP + +#include + +#if defined(__CCE__) +#include +#endif + +#include "catlass/detail/alignment.hpp" +#include "catlass/detail/dependent_false.hpp" +#include "catlass/detail/macros.hpp" + +namespace Catlass { + +constexpr uint32_t BYTE_PER_C0 = 32; +constexpr uint32_t BYTE_PER_C2 = 64; +constexpr uint32_t C0_NUM_PER_FRACTAL = 16; +constexpr uint32_t BYTE_PER_FRACTAL = BYTE_PER_C0 * C0_NUM_PER_FRACTAL; + +constexpr uint32_t BYTE_PER_BLK = 32; +constexpr uint32_t BLK_NUM_PER_VECTOR_FRACTAL = 8; +constexpr uint32_t BYTE_PER_VECTOR_FRACTAL = BYTE_PER_BLK * BLK_NUM_PER_VECTOR_FRACTAL; + +constexpr uint64_t L2_OFFSET = 0; +constexpr uint32_t STRIDE_LIMIT = 65536; + +} // namespace Catlass + +#endif // CATLASS_CATLASS_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/block/block_conv.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/block/block_conv.hpp new file mode 100644 index 00000000..06a35f3b --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/block/block_conv.hpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_CONV_BLOCK_BLOCK_CONV_HPP +#define CATLASS_CONV_BLOCK_BLOCK_CONV_HPP + +#include "catlass/catlass.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +namespace Catlass::Conv::Block { + +template , + class TileMmad = Gemm::Tile::TileMmad > +struct BlockConv { + static_assert(DEPENDENT_FALSE, "BlockConv is not implemented for this DispatchPolicy"); +}; +} // namespace Catlass::Conv::Block + +#include "catlass/conv/block/block_conv3d_pingpong_bias.hpp" + +#endif // CATLASS_CONV_BLOCK_BLOCK_CONV_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/block/block_conv3d_pingpong_bias.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/block/block_conv3d_pingpong_bias.hpp new file mode 100644 index 00000000..ce90f2d5 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/block/block_conv3d_pingpong_bias.hpp @@ -0,0 +1,690 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_CONV_BLOCK_BLOCK_CONV3D_PINGPONG_BIAS_HPP +#define CATLASS_CONV_BLOCK_BLOCK_CONV3D_PINGPONG_BIAS_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/conv_coord.hpp" +#include "catlass/conv/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +namespace Catlass::Conv::Block { + +template +struct BlockConv< + ConvAtlasA2Pingpong, + CoreTileShape_, FmapL1TileShape_, FilterL1TileShape_, L0TileShape_, FmapType_, FilterType_, OutType_, BiasType_, + TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = + ConvAtlasA2Pingpong; + using ArchTag = typename DispatchPolicy::ArchTag; + using CoreTileShape = CoreTileShape_; + using FmapL1TileShape = FmapL1TileShape_; + using FilterL1TileShape = FilterL1TileShape_; + using L0TileShape = L0TileShape_; + using ElementFmap = typename FmapType_::Element; + using LayoutFmap = typename FmapType_::Layout; + using ElementFilter = typename FilterType_::Element; + using LayoutFilter = typename FilterType_::Layout; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using ElementOut = typename OutType_::Element; + using LayoutOut = typename OutType_::Layout; + using ElementBias = typename BiasType_::Element; + using LayoutBias = typename BiasType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyGmToL1Bias = typename TileCopy_::CopyGmToL1Bias; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using CopyL1ToBT = typename TileCopy_::CopyL1ToBT; + + using LayoutFmapInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutFilterInL1 = typename 
CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr uint32_t L1A_STAGES = DispatchPolicy::L1A_STAGES; + static constexpr uint32_t L1B_STAGES = DispatchPolicy::L1B_STAGES; + static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES; + static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES; + static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = ArchTag::L0A_SIZE / L0A_STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = ArchTag::L0B_SIZE / L0B_STAGES; + static constexpr uint32_t L0C_PINGPONG_BUF_SIZE = ArchTag::L0C_SIZE / L0C_STAGES; + static constexpr uint32_t C0_SIZE = 32; + static constexpr uint32_t BLOCK_L0_M = 16; + static constexpr uint32_t BLOCK_L0_N = 16; + static constexpr uint32_t RIGHT_MOVE_8 = 8; + static constexpr uint32_t PAD_SIZE = 4; + static constexpr uint32_t PAD_IDX_T = 2; + static constexpr uint32_t PAD_IDX_B = 3; + static constexpr uint32_t PAD_IDX_L = 0; + static constexpr uint32_t PAD_IDX_R = 1; + static constexpr uint32_t BLOCK_SIZE = 512; + + // Check PingPong + static_assert(L1A_STAGES == 1, "L1A PingPong must be 1!"); + static_assert(L1B_STAGES == 1, "L1A PingPong must be 1!"); + static_assert(L0C_STAGES == 1, "L0C PingPong must be 1!"); + static_assert(L0A_STAGES == 2, "L0A PingPong must be 2!"); + static_assert(L0B_STAGES == 2, "L0B PingPong must be 2!"); + + ///// Construct 进行initBuffer + CATLASS_DEVICE + BlockConv(Arch::Resource &resource, Conv3dParams const &conv3dParams_, uint32_t l1BufAddrStart = 0) + : conv3dParams(conv3dParams_) + { + copyL1ToL0A = CopyL1ToL0A::MakeCopyL1ToL0A(conv3dParams.sW(), conv3dParams.sH(), conv3dParams.kw(), + conv3dParams.kh(), conv3dParams.dW(), conv3dParams.dH()); + uint64_t bl1Spacesize = FilterL1TileShape::Kd * FilterL1TileShape::Ci1 * conv3dParams.khkwcin0() * + FilterL1TileShape::nBL1 * sizeof(ElementFilter); + uint64_t hoAL1Max = FmapL1TileShape::mAL1 / conv3dParams.wo() + 2; + uint64_t hiAL1Max = (hoAL1Max - 1) * conv3dParams.sH() + conv3dParams.dilatedKernelH(); + hiAL1Max = hiAL1Max > conv3dParams.hi() ? 
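+        // clamp the estimated number of input rows per L1 tile to the actual feature-map height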
conv3dParams.hi() : hiAL1Max; + uint64_t al1Spacesize = + FmapL1TileShape::Kd * FmapL1TileShape::Ci1 * hiAL1Max * conv3dParams.wicin0() * sizeof(ElementFmap); + + for (uint32_t i = 0; i < L0A_STAGES; i++) { + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0AEventList[i] = i; + AscendC::SetFlag(l0AEventList[i]); + } + for (uint32_t i = 0; i < L0B_STAGES; i++) { + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + l0BEventList[i] = i + L0A_STAGES; + AscendC::SetFlag(l0BEventList[i]); + } + for (uint32_t i = 0; i < L0C_STAGES; i++) { + l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte(L0C_PINGPONG_BUF_SIZE * i); + l0CEventList[i] = i; + AscendC::SetFlag(l0CEventList[i]); + } + + uint32_t l1AOffset = l1BufAddrStart; + for (uint32_t i = 0; i < L1A_STAGES; i++) { + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + al1Spacesize * i); + l1AEventList[i] = i; + AscendC::SetFlag(l1AEventList[i]); + } + uint32_t l1BOffset = l1BufAddrStart + al1Spacesize * L1A_STAGES; + for (uint32_t i = 0; i < L1B_STAGES; i++) { + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + bl1Spacesize * i); + l1BEventList[i] = i + L1A_STAGES; + AscendC::SetFlag(l1BEventList[i]); + } + uint32_t l1BiasOffset = l1BufAddrStart + al1Spacesize * L1A_STAGES + bl1Spacesize * L1B_STAGES; + l1BiasTensor = resource.l1Buf.template GetBufferByByte(l1BiasOffset); + AscendC::SetFlag(L1A_STAGES + L1B_STAGES); + l0BiasTensor = resource.btBuf.template GetBufferByByte(0); + AscendC::SetFlag(L0A_STAGES + L0B_STAGES); + } + + /// Destructor + CATLASS_DEVICE + ~BlockConv() + { + for (uint32_t i = 0; i < L0A_STAGES; i++) { + AscendC::WaitFlag(l0AEventList[i]); + } + for (uint32_t i = 0; i < L0B_STAGES; i++) { + AscendC::WaitFlag(l0BEventList[i]); + } + for (uint32_t i = 0; i < L0C_STAGES; i++) { + AscendC::WaitFlag(l0CEventList[i]); + } + for (uint32_t i = 0; i < L1A_STAGES; i++) { + AscendC::WaitFlag(l1AEventList[i]); + } + for (uint32_t i = 0; i < L1B_STAGES; i++) { + AscendC::WaitFlag(l1BEventList[i]); + } + AscendC::WaitFlag(L1A_STAGES + L1B_STAGES); + AscendC::WaitFlag(L0A_STAGES + L0B_STAGES); + } + + // Perform a block-scoped conv3d + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &fmapGm, LayoutFmap const &layoutFmap, + AscendC::GlobalTensor const &filterGm, LayoutFilter const &layoutFilter, + AscendC::GlobalTensor const &outGm, LayoutFmap const &layoutOut, + AscendC::GlobalTensor const &biasGm, Conv3d6HdCoord const &actualBlockShape, + Conv3d6HdCoord const &actualIdxStartFmap) + { + // Initialization of the loop parameter in the K direction + iterParams.ddr2l0LoopK = CeilDiv(conv3dParams.alignCinKhKwKd(), L0TileShape::kL0); + iterParams.maxKL0Iter = iterParams.ddr2l0LoopK - 1; + iterParams.kL0Tail = conv3dParams.alignCinKhKwKd() % L0TileShape::kL0; + iterParams.kL0Tail = iterParams.kL0Tail == 0 ? 
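+        // a zero remainder means the aligned K dimension divides evenly into kL0 tiles, so the tail tile is a full kL0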
L0TileShape::kL0 : iterParams.kL0Tail; + + // The k-axis loop iteration parameters of the B-matrix + iterParams.maxKBL1Iter = CeilDiv(conv3dParams.kdcin1(), FilterL1TileShape::Kd * FilterL1TileShape::Ci1) - 1; + iterParams.multiKBL1 = + CeilDiv(FilterL1TileShape::Kd * FilterL1TileShape::Ci1 * conv3dParams.khkwcin0(), L0TileShape::kL0); + iterParams.kBL1fullload = conv3dParams.kdcin1() == FilterL1TileShape::Kd * FilterL1TileShape::Ci1; + uint32_t kBL1TailCheck = + conv3dParams.alignCinKhKwKd() % (FilterL1TileShape::Kd * FilterL1TileShape::Ci1 * conv3dParams.khkwcin0()); + iterParams.kBL1Tail = kBL1TailCheck == 0 + ? FilterL1TileShape::Kd * FilterL1TileShape::Ci1 * conv3dParams.khkwcin0() + : kBL1TailCheck; + + // The k-axis loop iteration parameters of matrix A + iterParams.maxKAL1Iter = CeilDiv(conv3dParams.kdcin1(), FmapL1TileShape::Kd * FmapL1TileShape::Ci1) - 1; + iterParams.multiKAL1 = + CeilDiv(FmapL1TileShape::Kd * FmapL1TileShape::Ci1 * conv3dParams.khkwcin0(), L0TileShape::kL0); + iterParams.kAL1fullload = conv3dParams.kdcin1() == FmapL1TileShape::Kd * FmapL1TileShape::Ci1; + uint32_t kAL1TailCheck = + conv3dParams.alignCinKhKwKd() % (FmapL1TileShape::Kd * FmapL1TileShape::Ci1 * conv3dParams.khkwcin0()); + iterParams.kAL1Tail = + kAL1TailCheck == 0 ? FmapL1TileShape::Kd * FmapL1TileShape::Ci1 * conv3dParams.khkwcin0() : kAL1TailCheck; + + // Loop parameters in the M direction + iterParams.mAL1Tail = actualBlockShape.hw() % FmapL1TileShape::mAL1; + iterParams.mAL1Tail = iterParams.mAL1Tail == 0 ? FmapL1TileShape::mAL1 : iterParams.mAL1Tail; + uint32_t mAL1DivmL0 = CeilDiv(FmapL1TileShape::mAL1, L0TileShape::mL0); + uint32_t ddr2l1LoopM = CeilDiv(actualBlockShape.hw(), FmapL1TileShape::mAL1); + iterParams.maxMAL1Iter = ddr2l1LoopM - 1; + iterParams.mAL0Tail = iterParams.mAL1Tail % L0TileShape::mL0; + iterParams.mAL0Tail = iterParams.mAL0Tail == 0 ? L0TileShape::mL0 : iterParams.mAL0Tail; + iterParams.l12l0LoopM = CeilDiv(FmapL1TileShape::mAL1, L0TileShape::mL0); + iterParams.maxML0Iter = iterParams.l12l0LoopM - 1; + + // Loop parameters in the Cout direction + iterParams.maxNBL1Iter = CeilDiv(actualBlockShape.c1() * conv3dParams.cout0(), FilterL1TileShape::nBL1) - 1; + iterParams.nBL1Tail = (actualBlockShape.c1() * conv3dParams.cout0()) % FilterL1TileShape::nBL1; + iterParams.nBL1Tail = iterParams.nBL1Tail == 0 ? FilterL1TileShape::nBL1 : iterParams.nBL1Tail; + uint32_t nBL1DivnL0 = CeilDiv(FilterL1TileShape::nBL1, L0TileShape::nL0); + iterParams.nBL1TailAlign = CeilDiv(iterParams.nBL1Tail, BLOCK_L0_N) * BLOCK_L0_N; + iterParams.nL0Tail = iterParams.nBL1Tail % L0TileShape::nL0; + iterParams.nL0Tail = iterParams.nL0Tail == 0 ? 
L0TileShape::nL0 : iterParams.nL0Tail; + iterParams.ddr2l1LoopN = iterParams.maxNBL1Iter + 1; + iterParams.l12l0LoopN = nBL1DivnL0; + iterParams.maxNL0Iter = iterParams.l12l0LoopN - 1; + + // Loop parameter in the D direction + iterParams.ddr2l1LoopD = actualBlockShape.d(); + + // The starting position of the input + iterParams.diStartPos = actualIdxStartFmap.d(); + iterParams.hwStartPos = actualIdxStartFmap.hw(); + + // Start the batch iterate + for (uint32_t batchIter = 0; batchIter < actualBlockShape.n(); ++batchIter) { + auto gmBatchFmap = fmapGm[batchIter * conv3dParams.fmapOneBatchSize()]; + auto gmBatchOut = outGm[batchIter * conv3dParams.outputOneBatchSize()]; + while (true) { + // The parameters used need to be reinitialized in the first iteration + if (iterParams.isFirstIterate) { + iterParams.nBL0Iter = 0; + iterParams.mAL0Iter = 0; + iterParams.mAL1Iter = 0; + iterParams.nBL1Iter = 0; + iterParams.dOutIter = 0; + iterParams.loadAL1Flag = true; + iterParams.loadBL1Flag = true; + iterParams.loadAL0Flag = true; + iterParams.loadBL0Flag = true; + iterParams.isFirstIterate = false; + if (L0TileShape::mL0 % conv3dParams.wo() == 0) { + iterParams.mL0IsDivisibleByWo = true; + } + } else { + // From N to M + iterParams.nBL0Iter++; + if (iterParams.nBL0Iter == iterParams.l12l0LoopN) { + iterParams.nBL0Iter = 0; + iterParams.mAL0Iter++; + } + if (iterParams.mAL0Iter == iterParams.l12l0LoopM) { + iterParams.mAL0Iter = 0; + iterParams.nBL1Iter++; + iterParams.loadBL1Flag = true; + } + if (iterParams.nBL1Iter == iterParams.ddr2l1LoopN) { + iterParams.nBL1Iter = 0; + iterParams.mAL1Iter++; + iterParams.loadAL1Flag = true; + } + if (iterParams.mAL1Iter == ddr2l1LoopM) { + iterParams.mAL1Iter = 0; + iterParams.dOutIter++; + } + if (iterParams.dOutIter == iterParams.ddr2l1LoopD) { + break; + } + } + // Refresh the cycle round + iterParams.l12l0LoopM = iterParams.mAL1Iter == iterParams.maxMAL1Iter + ? CeilDiv(iterParams.mAL1Tail, L0TileShape::mL0) + : mAL1DivmL0; + iterParams.maxML0Iter = iterParams.l12l0LoopM - 1; + iterParams.l12l0LoopN = iterParams.nBL1Iter == iterParams.maxNBL1Iter + ? CeilDiv(iterParams.nBL1Tail, L0TileShape::nL0) + : nBL1DivnL0; + iterParams.maxNL0Iter = iterParams.l12l0LoopN - 1; + // Start the K-axis iterate + uint32_t n = + (iterParams.nBL1Iter == iterParams.maxNBL1Iter && iterParams.nBL0Iter == iterParams.maxNL0Iter) + ? iterParams.nL0Tail + : L0TileShape::nL0; + uint32_t m = + (iterParams.mAL1Iter == iterParams.maxMAL1Iter && iterParams.mAL0Iter == iterParams.maxML0Iter) + ? 
iterParams.mAL0Tail + : L0TileShape::mL0; + + tileParams.l0CurrentM = CeilDiv(m, BLOCK_L0_M) * BLOCK_L0_M; + tileParams.l0CurrentN = CeilDiv(n, BLOCK_L0_N) * BLOCK_L0_N; + + uint32_t biasGmOffset = + iterParams.nBL1Iter * FilterL1TileShape::nBL1 + iterParams.nBL0Iter * L0TileShape::nL0; + + auto layoutTileBias = layout::VectorLayout(actualBlockShape.c1() * conv3dParams.n0()); + auto layoutBiasInL1 = layout::VectorLayout(tileParams.l0CurrentN); + auto l0BiasTile = l0BiasTensor; + AscendC::WaitFlag(L1A_STAGES + L1B_STAGES); + copyGmToL1Bias(l1BiasTensor, biasGm[biasGmOffset], layoutBiasInL1, layoutTileBias); + AscendC::SetFlag(L1A_STAGES + L1B_STAGES); + AscendC::WaitFlag(L1A_STAGES + L1B_STAGES); + auto layoutBiasInL0 = layout::VectorLayout(tileParams.l0CurrentN); + AscendC::WaitFlag(L0A_STAGES + L0B_STAGES); + copyL1ToBT(l0BiasTile, l1BiasTensor, layoutBiasInL0, layoutBiasInL1); + AscendC::SetFlag(L0A_STAGES + L0B_STAGES); + AscendC::WaitFlag(L0A_STAGES + L0B_STAGES); + AscendC::SetFlag(L1A_STAGES + L1B_STAGES); + iterParams.kIter = 0; + uint16_t isOdd = 0; + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::WaitFlag(EVENT_ID0); + } + + while (iterParams.kIter < iterParams.ddr2l0LoopK) { + if (iterParams.loadAL1Flag || + (!iterParams.kAL1fullload && iterParams.kIter % iterParams.multiKAL1 == 0)) { + AscendC::PipeBarrier(); + AscendC::SetFlag(l1AEventList[l1ListId]); + AscendC::WaitFlag(l1AEventList[l1ListId]); + LoadAL1Process(gmBatchFmap, iterParams.kIter / iterParams.multiKAL1, layoutFmap); + AscendC::SetFlag(l1AEventList[l1ListId]); + AscendC::WaitFlag(l1AEventList[l1ListId]); + } + if (iterParams.loadBL1Flag || + (!iterParams.kBL1fullload && iterParams.kIter % iterParams.multiKBL1 == 0)) { + AscendC::PipeBarrier(); + AscendC::SetFlag(l1BEventList[l1ListId]); + AscendC::WaitFlag(l1BEventList[l1ListId]); + LoadBL1Process(filterGm, iterParams.kIter / iterParams.multiKBL1, layoutFilter); + AscendC::SetFlag(l1BEventList[l1ListId]); + AscendC::WaitFlag(l1BEventList[l1ListId]); + } + ReduceKL0AL0BPingPong(isOdd); + iterParams.kIter++; + isOdd = iterParams.kIter & 0x1; + } + AscendC::SetFlag(L0A_STAGES + L0B_STAGES); + iterParams.kIter = 0; + auto layoutCInL0 = + LayoutCInL0::MakeLayoutInL0C(MakeCoord(m, CeilDiv(n, conv3dParams.cout0()) * conv3dParams.cout0())); + LayoutFmap layoutOutGm = + layoutOut.GetTileLayout(MakeCoord((uint32_t)1, conv3dParams.dout(), conv3dParams.cout1(), + conv3dParams.ho(), conv3dParams.wo(), conv3dParams.cout0())); + uint32_t cout1L1Idx = + (FilterL1TileShape::nBL1 * iterParams.nBL1Iter + L0TileShape::nL0 * iterParams.nBL0Iter) / + conv3dParams.cout0(); + uint32_t howoIdx = FmapL1TileShape::mAL1 * iterParams.mAL1Iter + L0TileShape::mL0 * iterParams.mAL0Iter; + Conv3d6HdCoord gmTileOutOffset{0, iterParams.dOutIter, cout1L1Idx, howoIdx}; + auto gmTileOut = gmBatchOut[layoutOut.GetOffset(gmTileOutOffset)]; + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + copyL0CToGm(gmTileOut, l0CTensorList[l0cListId], layoutOutGm, layoutCInL0); + AscendC::SetFlag(EVENT_ID0); + } else { + copyL0CToGm(gmTileOut, l0CTensorList[l0cListId], layoutOutGm, layoutCInL0, 0b11); + } + } + iterParams.isFirstIterate = true; + iterParams.nBL0Iter = 0; + iterParams.nBL1Iter = 0; + } + } + +protected: + struct IterParams { + uint8_t isFirstIterate = true; + uint8_t loadAL1Flag = true; + uint8_t loadBL1Flag = true; + uint8_t loadAL0Flag = true; + uint8_t loadBL0Flag = true; + uint8_t kAL1fullload = false; + uint8_t kBL1fullload = false; + uint8_t 
biasFullLoadFlag = false; + uint8_t mL0IsDivisibleByWo = false; + + uint8_t isGroupOptDimTail = false; + + uint32_t kAL1Iter = 0; + uint32_t kBL1Iter = 0; + uint32_t mAL1Iter = 0; + uint32_t nBL1Iter = 0; + uint32_t dOutIter = 0; + uint32_t kIter = 0; + uint32_t kAL0Iter = 0; + uint32_t kBL0Iter = 0; + uint32_t mAL0Iter = 0; + uint32_t nBL0Iter = 0; + uint32_t groupOptIter = 0; + + uint32_t maxKAL1Iter = 0; + uint32_t maxMAL1Iter = 0; + uint32_t maxNBL1Iter = 0; + uint32_t maxKBL1Iter = 0; + uint32_t maxNL0Iter = 0; + uint32_t maxML0Iter = 0; + uint32_t maxKL0Iter = 0; + uint32_t maxDOutIter = 0; + uint32_t maxGroupOptIter = 0; + + uint32_t ddr2l1LoopN = 0; + uint32_t l12l0LoopN = 0; + uint32_t ddr2l1LoopD = 0; + uint32_t l12l0LoopM = 0; + uint32_t ddr2l0LoopK = 0; + + uint32_t kL0Tail = 0; + uint32_t kAL1Tail = 0; + uint32_t kBL1Tail = 0; + uint32_t mAL1Tail = 0; + uint32_t mAL0Tail = 0; + uint32_t nL0Tail = 0; + uint32_t nBL1Tail = 0; + uint32_t multiKAL1 = 1; + uint32_t multiKBL1 = 1; + + uint32_t hwStartPos = 0; + uint32_t diStartPos = 0; + + uint32_t orgCoAlignK0 = 0; + uint32_t orgCoAlignN0 = 0; + uint32_t nBL1TailAlign = 0; + + bool aL1IsFullPad = false; + + CATLASS_DEVICE + IterParams() = default; + }; + + struct TileParams { + uint32_t l0CurrentM = 0; + uint32_t l0CurrentN = 0; + uint32_t l0NumN = 0; + + CATLASS_DEVICE + TileParams() = default; + }; + + Conv3dParams conv3dParams; + IterParams iterParams; + TileParams tileParams; + + AscendC::LocalTensor l1ATensorList[L1A_STAGES]; + AscendC::LocalTensor l1BTensorList[L1B_STAGES]; + AscendC::LocalTensor l0ATensorList[L0A_STAGES]; + AscendC::LocalTensor l0BTensorList[L0B_STAGES]; + AscendC::LocalTensor l0CTensorList[L0C_STAGES]; + AscendC::LocalTensor l1BiasTensor; + AscendC::LocalTensor l0BiasTensor; + + LayoutFmapInL1 layoutFmapInL1; + LayoutFilterInL1 layoutFilterInL1; + + // Multi-stage event id list + int32_t l1AEventList[L1A_STAGES]; + int32_t l1BEventList[L1B_STAGES]; + int32_t l0AEventList[L0A_STAGES]; + int32_t l0BEventList[L0B_STAGES]; + int32_t l0CEventList[L0C_STAGES]; + + // The id of current stage + uint32_t l1ListId{0}; + uint32_t l0cListId{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyGmToL1Bias copyGmToL1Bias; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL1ToBT copyL1ToBT; + CopyL0CToGm copyL0CToGm; + + CATLASS_DEVICE + void LoadAL1Process(AscendC::GlobalTensor const &gmBatchFmap, uint32_t kAL1Iter, + LayoutFmap const &layoutFmap) + { + iterParams.kAL1Iter = kAL1Iter; + uint32_t currentML1 = + iterParams.mAL1Iter == iterParams.maxMAL1Iter ? iterParams.mAL1Tail : FmapL1TileShape::mAL1; + uint32_t currentM = iterParams.hwStartPos + iterParams.mAL1Iter * FmapL1TileShape::mAL1; + uint32_t hoStartIdx = currentM / conv3dParams.wo(); + uint32_t hoEndIdx = CeilDiv(currentM + currentML1, conv3dParams.wo()); + uint32_t orgHiLoadL1 = ((hoEndIdx - hoStartIdx) - 1) * conv3dParams.sH() + conv3dParams.dilatedKernelH(); + uint32_t tmpCurCoreHiStartIdx = (iterParams.hwStartPos / conv3dParams.wo()) * conv3dParams.sH(); + uint32_t curCoreHiStartIdx = + tmpCurCoreHiStartIdx <= conv3dParams.padtop() ? 0 : tmpCurCoreHiStartIdx - conv3dParams.padtop(); + uint32_t currentCin1LoadL1 = iterParams.kAL1Iter * FmapL1TileShape::Kd * FmapL1TileShape::Ci1; + uint32_t kAL1Tmp = iterParams.kAL1Iter == iterParams.maxKAL1Iter + ? 
iterParams.kAL1Tail + : FmapL1TileShape::Kd * FmapL1TileShape::Ci1 * conv3dParams.khkwcin0(); + uint32_t orgCin1LoadL1 = kAL1Tmp / conv3dParams.kh() / conv3dParams.kw() / conv3dParams.cin0(); + uint32_t kdL1Idx = (currentCin1LoadL1 / conv3dParams.cin1()) % conv3dParams.kd(); + uint32_t cin1L1Idx = currentCin1LoadL1 % conv3dParams.cin1(); + uint32_t padTopL1 = 0; + uint32_t padBottomL1 = 0; + iterParams.aL1IsFullPad = false; + bool set2dFlagDHead = false; + bool set2dFlagDTail = false; + uint32_t hiLoadL1 = orgHiLoadL1; + uint32_t cin1LoadL1 = orgCin1LoadL1; + + uint32_t hiStartIdxWithPad = hoStartIdx * conv3dParams.sH(); + uint32_t hiEndIdxWithPad = hiStartIdxWithPad + hiLoadL1; + uint32_t hiIdx = hiStartIdxWithPad - conv3dParams.padtop() - curCoreHiStartIdx; + uint32_t hiWithPad = conv3dParams.hi() + conv3dParams.padtop(); + if (hiEndIdxWithPad <= conv3dParams.padtop()) { + iterParams.aL1IsFullPad = true; + } else if (hiStartIdxWithPad < conv3dParams.padtop()) { + hiIdx = 0; + hiLoadL1 = hiLoadL1 + hiStartIdxWithPad - conv3dParams.padtop(); + padTopL1 = conv3dParams.padtop() - hiStartIdxWithPad; + if (hiEndIdxWithPad >= hiWithPad) { + hiLoadL1 = conv3dParams.hi() - hiIdx; + padBottomL1 = hiEndIdxWithPad - hiWithPad; + } + } else if (hiStartIdxWithPad >= hiWithPad) { + iterParams.aL1IsFullPad = true; + } else if (hiEndIdxWithPad > hiWithPad) { + hiLoadL1 = hiWithPad - hiStartIdxWithPad; + padBottomL1 = hiEndIdxWithPad - hiWithPad; + } + + uint32_t diStartWithPad = + iterParams.diStartPos + iterParams.dOutIter * conv3dParams.sD() + kdL1Idx * conv3dParams.dD(); + uint32_t diEndWithPad = cin1LoadL1 <= conv3dParams.cin1() + ? diStartWithPad + 1 + : diStartWithPad + (cin1LoadL1 / conv3dParams.cin1() - 1) * conv3dParams.dD() + 1; + uint32_t diIdx = iterParams.diStartPos <= conv3dParams.padhead() ? diStartWithPad - conv3dParams.padhead() + : diStartWithPad - iterParams.diStartPos; + uint32_t diWithPad = conv3dParams.di() + conv3dParams.padhead(); + uint32_t cin1LoadL1PadHead = 0; + uint32_t cin1LoadL1PadTail = 0; + if (diEndWithPad <= conv3dParams.padhead()) { + iterParams.aL1IsFullPad = true; + } else if (diStartWithPad < conv3dParams.padhead()) { + set2dFlagDHead = true; + uint32_t kdTmp = CeilDiv((conv3dParams.padhead() - diStartWithPad), conv3dParams.dD()); + cin1LoadL1PadHead = kdTmp * conv3dParams.cin1(); + diIdx = conv3dParams.dD() == 1 ? 
0 : kdTmp * conv3dParams.dD() - conv3dParams.padhead() + diStartWithPad; + cin1LoadL1 -= cin1LoadL1PadHead; + + if (diEndWithPad > diWithPad) { + set2dFlagDTail = true; + kdTmp = CeilDiv((conv3dParams.di() - diIdx), conv3dParams.dD()); + cin1LoadL1PadTail = cin1LoadL1 - kdTmp * conv3dParams.cin1(); + cin1LoadL1 = kdTmp * conv3dParams.cin1(); + } + } else if (diStartWithPad >= diWithPad) { + iterParams.aL1IsFullPad = true; + } else if (diEndWithPad > diWithPad) { + set2dFlagDTail = true; + uint32_t kdTmp = CeilDiv((diWithPad - diStartWithPad), conv3dParams.dD()); + cin1LoadL1PadTail = cin1LoadL1 - kdTmp * conv3dParams.cin1(); + cin1LoadL1 = kdTmp * conv3dParams.cin1(); + } + if (!iterParams.aL1IsFullPad) { + uint8_t padList[PAD_SIZE] = {0}; + padList[PAD_IDX_L] = conv3dParams.padleft(); + padList[PAD_IDX_R] = conv3dParams.padright(); + padList[PAD_IDX_T] = padTopL1; + padList[PAD_IDX_B] = padBottomL1; + SetFmatrix(hiLoadL1, conv3dParams.wi(), padList, AscendC::FmatrixMode::FMATRIX_LEFT); + + uint64_t aL1Offset = 0; + if (set2dFlagDHead) { + AscendC::InitConstValueParams initConstValueParams; + initConstValueParams.repeatTimes = cin1LoadL1PadHead / conv3dParams.cin1(); + initConstValueParams.blockNum = conv3dParams.cin1() * hiLoadL1 * conv3dParams.wi(); + initConstValueParams.dstGap = 0; + initConstValueParams.initValue = 0; + InitConstValue(l1ATensorList[l1ListId], initConstValueParams); + aL1Offset += cin1LoadL1PadHead * hiLoadL1 * conv3dParams.wicin0(); + set2dFlagDHead = false; + } + + Conv3d6HdCoord gmTileFmapOffset{0, diIdx, cin1L1Idx, hiIdx * conv3dParams.wi()}; + auto layoutTileFmap = + layoutFmap.GetTileLayout(MakeCoord((uint32_t)1, conv3dParams.dD(), conv3dParams.cin1(), + conv3dParams.hi(), conv3dParams.wi(), conv3dParams.cin0())); + auto gmTileFmap = gmBatchFmap[layoutTileFmap.GetOffset(gmTileFmapOffset)]; + layoutFmapInL1 = + LayoutFmapInL1::MakeLayout(1, 1, cin1LoadL1, hiLoadL1, conv3dParams.wi(), conv3dParams.cin0()); + + copyGmToL1A(l1ATensorList[l1ListId][aL1Offset], gmTileFmap, layoutFmapInL1, layoutTileFmap); + + if (set2dFlagDTail) { + aL1Offset += cin1LoadL1 * hiLoadL1 * conv3dParams.wi() * conv3dParams.cin0(); + AscendC::InitConstValueParams initConstValueParams; + initConstValueParams.repeatTimes = cin1LoadL1PadTail / conv3dParams.cin1(); + initConstValueParams.blockNum = conv3dParams.cin1() * hiLoadL1 * conv3dParams.wi(); + initConstValueParams.dstGap = 0; + initConstValueParams.initValue = 0; + InitConstValue(l1ATensorList[l1ListId][aL1Offset], initConstValueParams); + set2dFlagDTail = false; + } + } + iterParams.loadAL1Flag = false; + layoutFmapInL1 = + LayoutFmapInL1::MakeLayout(1, 1, orgCin1LoadL1, orgHiLoadL1, conv3dParams.wi(), conv3dParams.cin0()); + } + + CATLASS_DEVICE + void LoadBL1Process(AscendC::GlobalTensor const &filterGm, uint32_t kBL1Iter, + LayoutFilter const &layoutFilter) + { + iterParams.kBL1Iter = kBL1Iter; + uint32_t currentNBL1 = + ((iterParams.nBL1Iter != iterParams.maxNBL1Iter) || (FilterL1TileShape::nBL1 >= conv3dParams.alignCout())) + ? FilterL1TileShape::nBL1 + : iterParams.nBL1TailAlign; + uint32_t currentKBL1 = iterParams.kBL1Iter == iterParams.maxKBL1Iter + ? 
iterParams.kBL1Tail + : FilterL1TileShape::Kd * FilterL1TileShape::Ci1 * conv3dParams.khkwcin0(); + Conv3dFracZ3dCoord gmTileFilterOffset{ + iterParams.kBL1Iter * FilterL1TileShape::Kd * FilterL1TileShape::Ci1 * conv3dParams.khkw(), + iterParams.nBL1Iter * FilterL1TileShape::nBL1}; + auto layoutTileFilter = layoutFilter; + auto gmTileFiler = filterGm[layoutTileFilter.GetOffset(gmTileFilterOffset)]; + layoutFilterInL1 = LayoutFilterInL1::template MakeLayout(currentKBL1, currentNBL1); + copyGmToL1B(l1BTensorList[l1ListId], gmTileFiler, layoutFilterInL1, layoutTileFilter); + iterParams.loadBL1Flag = false; + } + + CATLASS_DEVICE + void ReduceKL0AL0BPingPong(const uint16_t &l0abFlag) + { + auto l0ATile = l0ATensorList[l0abFlag]; + auto l0BTile = l0BTensorList[l0abFlag]; + AscendC::WaitFlag(l0AEventList[l0abFlag]); + AscendC::WaitFlag(l0BEventList[l0abFlag]); + iterParams.kAL0Iter = iterParams.kIter % iterParams.multiKAL1; + uint32_t currentKL0 = iterParams.kIter == iterParams.maxKL0Iter ? iterParams.kL0Tail : L0TileShape::kL0; + if (iterParams.aL1IsFullPad) { + uint32_t al0Set2dSpacesize_ = tileParams.l0CurrentM * currentKL0 * sizeof(ElementFmap) / BLOCK_SIZE; + AscendC::InitConstValueParams initConstValueParams(1, (uint16_t)al0Set2dSpacesize_, 0, 0); + InitConstValue(l0ATensorList[l0abFlag], initConstValueParams); + } else { + uint32_t kStartPt = iterParams.kAL0Iter * L0TileShape::kL0; + uint32_t mStartPt = + iterParams.mL0IsDivisibleByWo + ? iterParams.mAL0Iter * L0TileShape::mL0 + iterParams.hwStartPos % conv3dParams.wo() + : iterParams.mAL0Iter * L0TileShape::mL0 + + (iterParams.hwStartPos + iterParams.mAL1Iter * FmapL1TileShape::mAL1) % conv3dParams.wo(); + LayoutAInL0 layoutAInL0 = + LayoutAInL0::template MakeLayout(tileParams.l0CurrentM, currentKL0); + copyL1ToL0A(l0ATile, l1ATensorList[l1ListId], layoutAInL0, layoutFmapInL1, kStartPt, mStartPt); + } + iterParams.kBL0Iter = iterParams.kIter % iterParams.multiKBL1; + uint32_t tilingNBSrc_ = + (iterParams.nBL1Iter != iterParams.maxNBL1Iter) ? 
FilterL1TileShape::nBL1 : iterParams.nBL1TailAlign; + MatrixCoord l1TileFilterOffset{iterParams.kBL0Iter * L0TileShape::kL0, iterParams.nBL0Iter * L0TileShape::nL0}; + auto l1BTile = l1BTensorList[l1ListId][layoutFilterInL1.GetOffset(l1TileFilterOffset)]; + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(currentKL0, tileParams.l0CurrentN); + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, layoutFilterInL1); + AscendC::SetFlag(l0AEventList[l0abFlag]); + AscendC::SetFlag(l0BEventList[l0abFlag]); + AscendC::WaitFlag(l0AEventList[l0abFlag]); + AscendC::WaitFlag(l0BEventList[l0abFlag]); + auto l0CTile = l0CTensorList[l0cListId]; + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if (iterParams.kIter == iterParams.ddr2l0LoopK - 1) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + if (iterParams.kIter == 0) { + tileMmad(l0CTile, l0ATile, l0BTile, l0BiasTensor, tileParams.l0CurrentM, tileParams.l0CurrentN, currentKL0, + true, unitFlag); + } else { + tileMmad(l0CTile, l0ATile, l0BTile, tileParams.l0CurrentM, tileParams.l0CurrentN, currentKL0, false, + unitFlag); + } + AscendC::SetFlag(l0AEventList[l0abFlag]); + AscendC::SetFlag(l0BEventList[l0abFlag]); + } +}; +} // namespace Catlass::Conv::Block + +#endif // CATLASS_CONV_BLOCK_BLOCK_CONV3D_PINGPONG_BIAS_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/block/block_swizzle.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/block/block_swizzle.hpp new file mode 100644 index 00000000..d253e677 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/block/block_swizzle.hpp @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
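The loop-bound setup at the top of this block repeats one idiom: CeilDiv gives the number of tiles along an axis, and the tail tile is the axis extent modulo the tile size, folded back to a full tile when the division is exact (kAL1Tail, kBL1Tail, mAL1Tail and nL0Tail all follow it). A standalone sketch of that arithmetic with hypothetical sizes, not the real tiling parameters:

#include <cstdint>
#include <cstdio>

// Same rounding helper as catlass/detail/alignment.hpp, restated so the sketch compiles on its own.
static uint32_t CeilDiv(uint32_t dividend, uint32_t divisor) { return (dividend + divisor - 1) / divisor; }

int main()
{
    const uint32_t howo = 70;  // hypothetical M extent (ho * wo) of one core block
    const uint32_t mAL1 = 16;  // hypothetical FmapL1TileShape::mAL1

    uint32_t ddr2l1LoopM = CeilDiv(howo, mAL1);    // 5 L1 tiles along M
    uint32_t mAL1Tail = howo % mAL1;               // 6 rows left for the last tile
    mAL1Tail = (mAL1Tail == 0) ? mAL1 : mAL1Tail;  // an exact split means the tail is a full tile

    printf("loopM=%u tailM=%u\n", ddr2l1LoopM, mAL1Tail);  // loopM=5 tailM=6
    return 0;
}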
+ */ + +#ifndef CATLASS_CONV_BLOCK_BLOCK_SWIZZLE_HPP +#define CATLASS_CONV_BLOCK_BLOCK_SWIZZLE_HPP + +#include "catlass/catlass.hpp" +#include "catlass/detail/alignment.hpp" +#include "catlass/conv_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Conv::Block { +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// Block swizzling function for conv3d +template +struct Conv3dIdentityBlockSwizzle { + /// Data members + Conv3d6HdCoord outShape; + Conv3d6HdCoord coreTileShape; + MatrixCoord loopsMN; + Conv3d6HdCoord loops; + uint64_t nStart, doStart, co1Start, howoStart; + + // Methods + CATLASS_DEVICE + Conv3dIdentityBlockSwizzle() {} + + CATLASS_DEVICE + Conv3dIdentityBlockSwizzle(Conv3d6HdCoord const &outShape_, Conv3d6HdCoord const &loops_) + : outShape(outShape_), loops(loops_) + { + loops = Conv3d6HdCoord{min(outShape.n(), loops.n()), min(outShape.d(), loops.d()), + min(outShape.c1(), loops.c1()), min(outShape.hw(), loops.hw())}; + coreTileShape = Conv3d6HdCoord{CeilDiv(outShape.n(), loops.n()), CeilDiv(outShape.d(), loops.d()), + CeilDiv(outShape.c1(), loops.c1()), CeilDiv(outShape.hw(), loops.hw())}; + loopsMN = MatrixCoord{loops.hw(), loops.n() * loops.d() * loops.c1()}; + } + + CATLASS_DEVICE + uint32_t GetCoreLoops() const + { + return loopsMN.row() * loopsMN.column(); + } + + CATLASS_DEVICE + Conv3d6HdCoord GetBlockCoord(uint32_t taskIdx) + { + uint32_t innerIdx = taskIdx % GetCoreLoops(); + if constexpr (SwizzleDirection == 0) { + uint32_t tileBlockLoop = CeilDiv(loopsMN.row(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.column()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.column()); + + uint32_t nRow = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nRow = loopsMN.row() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; + uint32_t nIdx = inTileBlockIdx / nRow; + if (tileBlockIdx % 2 == 1) { + nIdx = loopsMN.column() - nIdx - 1; + } + + uint32_t howoIdx = mIdx; + + uint32_t noIdx = nIdx / (loops[1] * loops[2]); + uint32_t doIdx = nIdx / loops[2] % loops[1]; + uint32_t c1Idx = nIdx % loops[2]; + return Conv3d6HdCoord{noIdx, doIdx, c1Idx, howoIdx}; + } else if constexpr (SwizzleDirection == 1) { + uint32_t tileBlockLoop = CeilDiv(loopsMN.column(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.row()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.row()); + + uint32_t nCol = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nCol = loopsMN.column() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = inTileBlockIdx / nCol; + uint32_t nIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; + if (tileBlockIdx % 2 == 1) { + mIdx = loopsMN.row() - mIdx - 1; + } + + uint32_t howoIdx = mIdx; + + uint32_t noIdx = nIdx / (loops[1] * loops[2]); + uint32_t doIdx = nIdx / loops[2] % loops[1]; + uint32_t c1Idx = nIdx % loops[2]; + return Conv3d6HdCoord{noIdx, doIdx, c1Idx, howoIdx}; + } + } + + CATLASS_DEVICE + Conv3d6HdCoord GetDimStartIdx(Conv3d6HdCoord blockCoord) + { + uint32_t nStart = blockCoord.n() * coreTileShape.n(); + uint32_t doStart = blockCoord.d() * coreTileShape.d(); + uint32_t c1Start = blockCoord.c1() * coreTileShape.c1(); + uint32_t howoStart = blockCoord.hw() * coreTileShape.hw(); + return Conv3d6HdCoord{nStart, doStart, c1Start, howoStart}; + } + + CATLASS_DEVICE + Conv3d6HdCoord GetActualBlockShape(Conv3d6HdCoord 
blockCoord, Conv3d6HdCoord dimStartIdx) + { + uint32_t nActual = (blockCoord.n() == loops.n() - 1) ? (outShape[0] - dimStartIdx.n()) : coreTileShape.n(); + + uint32_t doActual = (blockCoord.d() == loops.d() - 1) ? (outShape[1] - dimStartIdx.d()) : coreTileShape.d(); + + uint32_t c1Actual = (blockCoord.c1() == loops.c1() - 1) ? (outShape[2] - dimStartIdx.c1()) : coreTileShape.c1(); + + uint32_t hwActual = (blockCoord.hw() == loops.hw() - 1) ? (outShape[3] - dimStartIdx.hw()) : coreTileShape.hw(); + return Conv3d6HdCoord{nActual, doActual, c1Actual, hwActual}; + } +}; +} // namespace Catlass::Conv::Block + +#endif // CATLASS_CONV_BLOCK_BLOCK_SWIZZLE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/device/device_conv.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/device/device_conv.hpp new file mode 100644 index 00000000..a23b033f --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/device/device_conv.hpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_CONV_DEVICE_DEVICE_CONV_HPP +#define CATLASS_CONV_DEVICE_DEVICE_CONV_HPP + +#include +#include "catlass/catlass.hpp" +#include "catlass/status.hpp" +#include "catlass/conv/device/kernel_adapter.hpp" + +#if defined(ENABLE_ASCENDC_DUMP) +#include "catlass/debug.hpp" +#endif + +namespace Catlass::Conv::Device { + +template +class DeviceConv +{ +public: + /// Argument structure: User API + using Arguments = typename ConvKernel::Arguments; + /// Argument structure: Kernel API + using Params = typename ConvKernel::Params; + +private: + /// kernel API parameters object + Params params_; + +public: + DeviceConv() {} + ~DeviceConv() {} + + /// Access the Params structure + Params const ¶ms() const + { + return params_; + } + + /// Determines whether the Conv can execute the given problem. + static Status CanImplement(Arguments const &args) + { + if (ConvKernel::CanImplement(args)) { + return Status::kSuccess; + } else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t GetWorkspaceSize(Arguments const &args) + { + size_t workspace_bytes = 0; + workspace_bytes += ConvKernel::GetWorkspaceSize(args); + return workspace_bytes; + } + + /// Initializes Conv state from arguments + Status Initialize(Arguments const &args, uint8_t *workspace = nullptr, aclrtStream stream = nullptr) + { + // Initialize the Params structure + params_ = ConvKernel::ToUnderlyingArguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+ /// Supplied params struct must be construct by calling matmul Kernel::to_underling arguments + inline Status Run(aclrtStream stream, uint32_t blockDim, uint64_t fftsAddr) + { +#if defined(ENABLE_ASCENDC_DUMP) + uint8_t *ptrDump{nullptr}; + aclCheck(aclrtMalloc(reinterpret_cast(&ptrDump), ALL_DUMPSIZE, ACL_MEM_MALLOC_HUGE_FIRST)); + if (fftsAddr == 0) { + Catlass::KernelAdapter<<>>(params_, ptrDump); + } else { + Catlass::KernelAdapter<<>>(params_, fftsAddr, ptrDump); + } + aclCheck(aclrtSynchronizeStream(stream)); + Adx::AdumpPrintWorkSpace(ptrDump, ALL_DUMPSIZE, stream, "device_gemm"); + aclCheck(aclrtFree(ptrDump)); +#else + if (fftsAddr == 0) { + Catlass::KernelAdapter<<>>(params_); + } else { + Catlass::KernelAdapter<<>>(params_, fftsAddr); + } +#endif + return Status::kSuccess; + } + + /// Runs the kernel using initialized state + inline Status operator()(aclrtStream stream, uint32_t blockDim) + { + return Run(stream, blockDim, 0); + } + + inline Status operator()(aclrtStream stream, uint32_t blockDim, uint64_t fftsAddr) + { + return Run(stream, blockDim, fftsAddr); + } +}; +/////////////////////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Conv::Device +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/device/kernel_adapter.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/device/kernel_adapter.hpp new file mode 100644 index 00000000..d8649ff3 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/device/kernel_adapter.hpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef CATLASS_CONV_DEVICE_KERNEL_ADAPTER_HPP +#define CATLASS_CONV_DEVICE_KERNEL_ADAPTER_HPP + +#include "catlass/catlass.hpp" + +#if defined(ENABLE_ASCENDC_DUMP) +#include "catlass/debug.hpp" +#endif + +namespace Catlass { +/// Generic Catlass kernel template +template +CATLASS_GLOBAL void KernelAdapter(typename Operator::Params params, GM_ADDR ptrDump = nullptr) +{ + Operator op; +#if defined(ENABLE_ASCENDC_DUMP) + AscendC::InitDump(false, ptrDump, ALL_DUMPSIZE); +#endif + op(params); +} + +template +CATLASS_GLOBAL void KernelAdapter(typename Operator::Params params, uint64_t fftsAddr, GM_ADDR ptrDump = nullptr) +{ + AscendC::SetSyncBaseAddr(fftsAddr); + Operator op; +#if defined(ENABLE_ASCENDC_DUMP) + AscendC::InitDump(false, ptrDump, ALL_DUMPSIZE); +#endif + op(params); +} +} // namespace Catlass + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/dispatch_policy.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/dispatch_policy.hpp new file mode 100644 index 00000000..386b051a --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/dispatch_policy.hpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
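DeviceConv together with KernelAdapter is the host-side entry point: check the problem, size the workspace, convert Arguments to Params, then launch. A host-side sketch of that flow, assuming ConvKernel is a fully specialized Kernel::ConvBias, that Status lives in namespace Catlass as the includes suggest, and that stream, blockDim and fftsAddr are supplied by the caller:

#include <cstdint>
#include <acl/acl.h>
#include "catlass/conv/device/device_conv.hpp"

// Sketch of the launch sequence exposed by DeviceConv above.
template <class ConvKernel>
Catlass::Status LaunchConv3d(typename ConvKernel::Arguments const &args, aclrtStream stream,
                             uint32_t blockDim, uint64_t fftsAddr)
{
    using Device = Catlass::Conv::Device::DeviceConv<ConvKernel>;

    if (Device::CanImplement(args) != Catlass::Status::kSuccess) {
        return Catlass::Status::kInvalid;  // problem rejected by the kernel
    }

    // ConvBias reports zero workspace bytes; a real caller would allocate this with aclrtMalloc.
    size_t workspaceBytes = Device::GetWorkspaceSize(args);
    uint8_t *workspace = nullptr;
    if (workspaceBytes != 0) {
        return Catlass::Status::kInvalid;  // sketch only: no workspace handling here
    }

    Device conv;
    conv.Initialize(args, workspace, stream);  // builds Params via ToUnderlyingArguments
    return conv(stream, blockDim, fftsAddr);   // dispatches Catlass::KernelAdapter on the stream
}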
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_CONV_DISPATCH_POLICY_HPP +#define CATLASS_CONV_DISPATCH_POLICY_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" + +namespace Catlass::Conv { + +// Block Mmad Policies + +template +struct ConvAtlasA2Base { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t ASYNC = ASYNC_; +}; + +using ConvAtlasA2 = ConvAtlasA2Base; +using ConvAtlasA2Async = ConvAtlasA2Base; + +template +struct ConvAtlasA2Pingpong : public ConvAtlasA2 { + static constexpr uint32_t L1A_STAGES = L1A_STAGES_; + static constexpr uint32_t L1B_STAGES = L1B_STAGES_; + static constexpr uint32_t L0A_STAGES = L0A_STAGES_; + static constexpr uint32_t L0B_STAGES = L0B_STAGES_; + static constexpr uint32_t L0C_STAGES = L0C_STAGES_; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; +}; + +} // namespace Catlass::Conv + +#endif // CATLASS_CONV_DISPATCH_POLICY_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/kernel/conv3d_bias.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/kernel/conv3d_bias.hpp new file mode 100644 index 00000000..b2c6b4a4 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv/kernel/conv3d_bias.hpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_CONV_KERNEL_BASIC_CONV3D_HPP +#define CATLASS_CONV_KERNEL_BASIC_CONV3D_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/conv_coord.hpp" +#include "catlass/coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Conv::Kernel { + +// Template for conv3d with bias kernel. 
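Before the kernel template below, a hypothetical instantiation of the pingpong dispatch policy just defined. The policy is a pure compile-time record: the stage counts size the L1/L0 tensor and event-ID arrays in the block-level conv, and ENABLE_UNIT_FLAG selects the unit-flag synchronization path. The (stripped) template parameter list is assumed to follow the member order L1A, L1B, L0A, L0B, L0C, unit flag:

#include <type_traits>
#include "catlass/conv/dispatch_policy.hpp"

// Double-buffered L1 and L0A/L0B, a single L0C stage, unit-flag synchronization enabled.
using PingpongPolicy = Catlass::Conv::ConvAtlasA2Pingpong<2, 2, 2, 2, 1, true>;

static_assert(PingpongPolicy::L1A_STAGES == 2, "sizes l1ATensorList / l1AEventList in the block conv");
static_assert(PingpongPolicy::L0C_STAGES == 1, "single accumulator buffer");
static_assert(PingpongPolicy::ENABLE_UNIT_FLAG, "take the unit-flag branch in ReduceKL0AL0BPingPong");
static_assert(std::is_same_v<PingpongPolicy::ArchTag, Catlass::Arch::AtlasA2>, "inherited from ConvAtlasA2");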
+template +class ConvBias +{ +public: + using BlockConv = BlockConv_; + using ArchTag = typename BlockConv::ArchTag; + using CoreTileShape = typename BlockConv::CoreTileShape; + using ElementFmap = typename BlockConv::ElementFmap; + using LayoutFmap = typename BlockConv::LayoutFmap; + using ElementFilter = typename BlockConv::ElementFilter; + using LayoutFilter = typename BlockConv::LayoutFilter; + using ElementOut = typename BlockConv::ElementOut; + using LayoutOut = typename BlockConv::LayoutOut; + using ElementBias = typename BlockConv::ElementBias; + using ElementAccumulator = typename BlockConv::ElementAccumulator; + + // using Conv3dParams = typename Catlass::Conv3dParams; + + using BlockScheduler = BlockScheduler_; + + struct Params { + // Data members + Conv3dParams problemShape; + GM_ADDR ptrFmap; + LayoutFmap layoutFmap; + GM_ADDR ptrFilter; + LayoutFilter layoutFilter; + GM_ADDR ptrOut; + LayoutOut layoutOut; + GM_ADDR ptrBias; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(Conv3dParams const &problemShape_, GM_ADDR ptrFmap_, LayoutFmap const &layoutFmap_, GM_ADDR ptrFilter_, + LayoutFilter const &layoutFilter_, GM_ADDR ptrOut_, LayoutOut const &layoutOut_, GM_ADDR ptrBias_) + : problemShape(problemShape_), + ptrFmap(ptrFmap_), + layoutFmap(layoutFmap_), + ptrFilter(ptrFilter_), + layoutFilter(layoutFilter_), + ptrOut(ptrOut_), + layoutOut(layoutOut_), + ptrBias(ptrBias_) + {} + }; + + struct Arguments { + Conv3dParams problemShape; + GM_ADDR ptrFmap; + GM_ADDR ptrFilter; + GM_ADDR ptrOut; + GM_ADDR ptrBias; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + LayoutFmap layoutFmap = + LayoutFmap::MakeLayout(args.problemShape.batch(), args.problemShape.di(), args.problemShape.cin1(), + args.problemShape.hi(), args.problemShape.wi(), args.problemShape.cin0()); + LayoutFilter layoutFilter = LayoutFilter::MakeLayout(args.problemShape.kdc1khkw(), args.problemShape.n1(), + args.problemShape.n0(), args.problemShape.cin0()); + LayoutOut layoutOut = + LayoutOut::MakeLayout(args.problemShape.batch(), args.problemShape.dout(), args.problemShape.cout1(), + args.problemShape.ho(), args.problemShape.wo(), args.problemShape.cout0()); + Params params{args.problemShape, args.ptrFmap, layoutFmap, args.ptrFilter, + layoutFilter, args.ptrOut, layoutOut, args.ptrBias}; + return params; + } + + // Methods + CATLASS_DEVICE + ConvBias() {} + + template + CATLASS_DEVICE ConvBias() + {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler convBlockScheduler( + Conv3d6HdCoord{params.problemShape.batch(), params.problemShape.dout(), params.problemShape.cout1(), + params.problemShape.howo()}, + Conv3d6HdCoord{CoreTileShape::noCnt, CoreTileShape::doCnt, CoreTileShape::co1Cnt, CoreTileShape::howoCnt}); + uint32_t coreLoops = convBlockScheduler.GetCoreLoops(); + + Arch::Resource resource; + BlockConv blockConv(resource, params.problemShape); + + // represent the full gm + AscendC::GlobalTensor fmapGm; + fmapGm.SetGlobalBuffer((__gm__ ElementFmap *)params.ptrFmap); + AscendC::GlobalTensor filterGm; + filterGm.SetGlobalBuffer((__gm__ ElementFilter *)params.ptrFilter); + AscendC::GlobalTensor outGm; + outGm.SetGlobalBuffer((__gm__ ElementOut *)params.ptrOut); + AscendC::GlobalTensor 
biasGm; + biasGm.SetGlobalBuffer((__gm__ ElementBias *)params.ptrBias); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + Conv3d6HdCoord blockCoord = convBlockScheduler.GetBlockCoord(loopIdx); + Conv3d6HdCoord dimStartCoord = convBlockScheduler.GetDimStartIdx(blockCoord); + Conv3d6HdCoord actualBlockShape = convBlockScheduler.GetActualBlockShape(blockCoord, dimStartCoord); + + uint32_t diIdxStart = Max(dimStartCoord.d() * params.problemShape.sD(), params.problemShape.padhead(), 0); + uint32_t hiwiIdxStart = Max((dimStartCoord.hw() / params.problemShape.wo()) * params.problemShape.sH(), + params.problemShape.padtop(), 0) * + params.problemShape.wi(); + + // Compute initial location in logical coordinates + Conv3d6HdCoord offsetFmap{dimStartCoord.n(), diIdxStart, 0, hiwiIdxStart}; + Conv3dFracZ3dCoord offsetFilter{0, dimStartCoord.c1() * params.problemShape.cout0()}; + Conv3d6HdCoord offsetOut{dimStartCoord.n(), dimStartCoord.d(), dimStartCoord.c1(), dimStartCoord.hw()}; + Conv3d6HdCoord actualIdxStartFmap{0, dimStartCoord.d() * params.problemShape.sD(), 0, dimStartCoord.hw()}; + + int64_t gmOffsetFmap = params.layoutFmap.GetOffset(offsetFmap); + int64_t gmOffsetFilter = params.layoutFilter.GetOffset(offsetFilter); + int64_t gmOffsetOut = params.layoutOut.GetOffset(offsetOut); + int64_t gmOffsetBias = dimStartCoord.c1() * params.problemShape.cout0(); + + blockConv(fmapGm[gmOffsetFmap], params.layoutFmap, filterGm[gmOffsetFilter], params.layoutFilter, + outGm[gmOffsetOut], params.layoutOut, biasGm[gmOffsetBias], actualBlockShape, actualIdxStartFmap); + } + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + {} + + __aicore__ inline uint32_t Max(uint32_t a, uint32_t b, uint32_t c) + { + if (a > b) { + return a - b; + } else { + return c; + } + } +}; + +} // namespace Catlass::Conv::Kernel +#endif // CATLASS_CONV_KERNEL_BASIC_CONV3D_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv_coord.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv_coord.hpp new file mode 100644 index 00000000..f61449bd --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/conv_coord.hpp @@ -0,0 +1,460 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
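One note on the Max helper above: despite its name it returns a - b when a > b and c otherwise, and both call sites pass c = 0, so it is effectively a saturating subtraction that turns a padded start coordinate (block start times stride) into an in-bounds input index. A standalone restatement with hypothetical values:

#include <cassert>
#include <cstdint>

// Clamped subtraction, equivalent to what ConvBias::Max computes when c == 0.
static uint32_t StartMinusPad(uint32_t paddedStart, uint32_t pad)
{
    return paddedStart > pad ? paddedStart - pad : 0;
}

int main()
{
    // Hypothetical: the block starts at output depth 3, strideD = 2, padhead = 1.
    assert(StartMinusPad(3 * 2, 1) == 5);  // diIdxStart
    // A block that starts inside the top padding clamps to row 0 of the input.
    assert(StartMinusPad(0 * 2, 1) == 0);  // row index before scaling by wi
    return 0;
}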
+ */ + +#ifndef CATLASS_CONV_COORD_HPP +#define CATLASS_CONV_COORD_HPP + +#include "catlass/coord.hpp" + +namespace Catlass { + +/// Shape of conv3d operation +struct Conv3dParams { +public: + typedef uint32_t Index; + static constexpr uint32_t N0 = 16; + using Fmap6HDShape = Coord<6, Index>; // {batch, di, cin1, hi, wi, cin0} + using FilterFracZ3DShape = Coord<7, Index>; // {kd, cin1, kh, kw, n1, n0, cin0} + using Out6HDShape = Coord<6, Index>; // {batch, do, cout1, ho, wo, cout0} + using Strides = Coord<3, Index>; + using Pads = Coord<3, Index>; + using Dilations = Coord<3, Index>; + +private: + Fmap6HDShape fmap6HDShape_; + FilterFracZ3DShape filterFracZ3DShape_; + Out6HDShape out6HDShape_; + Strides strides_; + Pads pads_; + Dilations dilations_; + Index cout_; + +public: + CATLASS_HOST_DEVICE + Conv3dParams(Index BATCH = 1, Index Di = 1, Index Cin1 = 1, Index Hi = 1, Index Wi = 1, Index C0 = 16, Index Kd = 1, + Index Kh = 1, Index Kw = 1, Index N1 = 1, Index Do = 1, Index Ho = 1, Index Wo = 1, Index Cout1 = 1, + Index Cout = 1, Index padHead = 0, Index padTop = 0, Index padLeft = 0, Index strideD = 1, + Index strideH = 1, Index strideW = 1, Index dilationD = 1, Index dilationH = 1, Index dilationW = 1) + : fmap6HDShape_(MakeCoord(BATCH, Di, Cin1, Hi, Wi, C0)), + filterFracZ3DShape_(MakeCoord(Kd, Cin1, Kh, Kw, N1, N0, C0)), + out6HDShape_(MakeCoord(BATCH, Do, Cout1, Ho, Wo, C0)), + cout_(Cout), + pads_(MakeCoord(padHead, padTop, padLeft)), + strides_(MakeCoord(strideD, strideH, strideW)), + dilations_(MakeCoord(dilationD, dilationH, dilationW)) + {} + + CATLASS_HOST_DEVICE + static Conv3dParams MakeConvCoord(const uint32_t *fmapShape, const uint32_t *filterShape, const uint32_t *paddings, + const uint32_t *strides, const uint32_t *dilations) + { + return Conv3dParams( + fmapShape[0], fmapShape[1], fmapShape[2], fmapShape[3], fmapShape[4], fmapShape[5], filterShape[0], + filterShape[1], filterShape[2], CeilDiv(filterShape[3], N0), + (fmapShape[1] + paddings[0] * 2 - dilations[0] * (filterShape[0] - 1) - 1) / strides[0] + 1, // Do + (fmapShape[3] + paddings[1] * 2 - dilations[1] * (filterShape[1] - 1) - 1) / strides[1] + 1, // Ho + (fmapShape[4] + paddings[2] * 2 - dilations[2] * (filterShape[2] - 1) - 1) / strides[2] + 1, // Wo + CeilDiv(filterShape[3], fmapShape[5]), filterShape[3], paddings[0], paddings[1], paddings[2], strides[0], + strides[1], strides[2], dilations[0], dilations[1], dilations[2]); + } + + // fmapShape + CATLASS_HOST_DEVICE + Index const &batch() const + { + return fmap6HDShape_[0]; + } + CATLASS_HOST_DEVICE + Index const &cin1() const + { + return fmap6HDShape_[2]; + } + CATLASS_HOST_DEVICE + Index const &di() const + { + return fmap6HDShape_[1]; + } + CATLASS_HOST_DEVICE + Index const &hi() const + { + return fmap6HDShape_[3]; + } + CATLASS_HOST_DEVICE + Index const &wi() const + { + return fmap6HDShape_[4]; + } + CATLASS_HOST_DEVICE + Index const &cin0() const + { + return fmap6HDShape_[5]; + } + CATLASS_HOST_DEVICE + Index const hiwi() const + { + return fmap6HDShape_[3] * fmap6HDShape_[4]; + } + + // filterShape + CATLASS_HOST_DEVICE + Index const &kd() const + { + return filterFracZ3DShape_[0]; + } + CATLASS_HOST_DEVICE + Index const &kh() const + { + return filterFracZ3DShape_[2]; + } + CATLASS_HOST_DEVICE + Index const &kw() const + { + return filterFracZ3DShape_[3]; + } + CATLASS_HOST_DEVICE + Index const khkw() const + { + return filterFracZ3DShape_[2] * filterFracZ3DShape_[3]; + } + CATLASS_HOST_DEVICE + Index const kdc1khkw() const + { + return 
filterFracZ3DShape_[0] * filterFracZ3DShape_[1] * filterFracZ3DShape_[2] * filterFracZ3DShape_[3]; + } + CATLASS_HOST_DEVICE + Index const &n1() const + { + return filterFracZ3DShape_[4]; + } + CATLASS_HOST_DEVICE + Index const &n0() const + { + return filterFracZ3DShape_[5]; + } + + // outShape + CATLASS_HOST_DEVICE + Index const &dout() const + { + return out6HDShape_[1]; + } + CATLASS_HOST_DEVICE + Index const &ho() const + { + return out6HDShape_[3]; + } + CATLASS_HOST_DEVICE + Index const &wo() const + { + return out6HDShape_[4]; + } + CATLASS_HOST_DEVICE + Index const &cout1() const + { + return out6HDShape_[2]; + } + CATLASS_HOST_DEVICE + Index const &cout0() const + { + return out6HDShape_[5]; + } + CATLASS_HOST_DEVICE + Index const &cout() const + { + return cout_; + } + + /// paddings + CATLASS_HOST_DEVICE + Index const &padhead() const + { + return pads_[0]; + } + CATLASS_HOST_DEVICE + Index const &padtail() const + { + return pads_[0]; + } + CATLASS_HOST_DEVICE + Index const &padtop() const + { + return pads_[1]; + } + CATLASS_HOST_DEVICE + Index const &padbottom() const + { + return pads_[1]; + } + CATLASS_HOST_DEVICE + Index const &padleft() const + { + return pads_[2]; + } + CATLASS_HOST_DEVICE + Index const &padright() const + { + return pads_[2]; + } + + /// strideSize + CATLASS_HOST_DEVICE + Index const &sD() const + { + return strides_[0]; + } + CATLASS_HOST_DEVICE + Index const &sH() const + { + return strides_[1]; + } + CATLASS_HOST_DEVICE + Index const &sW() const + { + return strides_[2]; + } + + /// dilationSize + CATLASS_HOST_DEVICE + Index const &dD() const + { + return dilations_[0]; + } + CATLASS_HOST_DEVICE + Index const dilatedKernelD() const + { + return 1 + (filterFracZ3DShape_[0] - 1) * dilations_[0]; + } + CATLASS_HOST_DEVICE + Index const &dH() const + { + return dilations_[1]; + } + CATLASS_HOST_DEVICE + Index const dilatedKernelH() const + { + return 1 + (filterFracZ3DShape_[2] - 1) * dilations_[1]; + } + CATLASS_HOST_DEVICE + Index const &dW() const + { + return dilations_[2]; + } + CATLASS_HOST_DEVICE + Index const dilatedKernelW() const + { + return 1 + (filterFracZ3DShape_[3] - 1) * dilations_[2]; + } + + ///// used in block + CATLASS_HOST_DEVICE + Index const howo() const + { + return out6HDShape_[3] * out6HDShape_[4]; + } + CATLASS_HOST_DEVICE + Index const alignCout() const + { + return out6HDShape_[2] * out6HDShape_[5]; + } + CATLASS_HOST_DEVICE + Index const wicin0() const + { + return fmap6HDShape_[4] * fmap6HDShape_[5]; + } + CATLASS_HOST_DEVICE + Index const khkwcin0() const + { + return filterFracZ3DShape_[2] * filterFracZ3DShape_[3] * filterFracZ3DShape_[6]; + } + CATLASS_HOST_DEVICE + Index const alignCinKhKwKd() const + { + return filterFracZ3DShape_[0] * filterFracZ3DShape_[1] * filterFracZ3DShape_[2] * filterFracZ3DShape_[3] * + filterFracZ3DShape_[6]; + } + CATLASS_HOST_DEVICE + Index const kdcin1() const + { + return filterFracZ3DShape_[0] * filterFracZ3DShape_[1]; + } + CATLASS_HOST_DEVICE + Index const fmapOneBatchSize() const + { + return fmap6HDShape_[1] * fmap6HDShape_[2] * fmap6HDShape_[3] * fmap6HDShape_[4] * fmap6HDShape_[5]; + } + CATLASS_HOST_DEVICE + Index const outputOneBatchSize() const + { + return out6HDShape_[1] * out6HDShape_[2] * out6HDShape_[3] * out6HDShape_[4] * out6HDShape_[5]; + } +}; + +template +struct ConvCoreShape { + static uint32_t const noCnt = noCnt_; + static uint32_t const doCnt = doCnt_; + static uint32_t const co1Cnt = co1Cnt_; + static uint32_t const howoCnt = howoCnt_; + + /// Returns a Coord 
object + CATLASS_HOST_DEVICE + static Coord<4> ToCoord() + { + return MakeCoord(noCnt, doCnt, co1Cnt, howoCnt); + } +}; + +template +struct ConvFmapL1Shape { + static uint32_t constexpr mAL1 = mAL1_; + static uint32_t constexpr Kd = Kd_; + static uint32_t constexpr Ci1 = Ci1_; + + /// Returns a Coord object + CATLASS_HOST_DEVICE + static Coord<3> ToCoord() + { + return MakeCoord(mAL1, Kd, Ci1); + } +}; + +template +struct ConvFilterL1Shape { + static uint32_t constexpr Kd = Kd_; + static uint32_t constexpr Ci1 = Ci1_; + static uint32_t constexpr nBL1 = nBL1_; + + /// Returns a Coord object + CATLASS_HOST_DEVICE + static Coord<3> ToCoord() + { + return MakeCoord(Kd, Ci1, nBL1); + } +}; + +template +struct ConvL0Shape { + static uint32_t constexpr mL0 = mL0_; + static uint32_t constexpr kL0 = kL0_; + static uint32_t constexpr nL0 = nL0_; + + /// Returns a Coord object + CATLASS_HOST_DEVICE + static Coord<3> ToCoord() + { + return MakeCoord(mL0, kL0, nL0); + } +}; + +struct Conv3d6HdCoord : public Coord<4, uint32_t> { + using Index = uint32_t; + + using Base = Coord<4, Index>; + + static constexpr int N_INDEX = 0; + static constexpr int D_INDEX = 1; + static constexpr int C1_INDEX = 2; + static constexpr int HW_INDEX = 3; + + /// Default ctor + CATLASS_HOST_DEVICE + Conv3d6HdCoord() {} + + CATLASS_HOST_DEVICE + Conv3d6HdCoord(Coord<4, Index> const &coord) : Base(coord) {} + + CATLASS_HOST_DEVICE + Conv3d6HdCoord(Index n, Index d, Index c1, Index hw) : Base(MakeCoord(n, d, c1, hw)) {} + + CATLASS_HOST_DEVICE + Index const &n() const + { + return this->At(N_INDEX); + } + CATLASS_HOST_DEVICE + Index &n() + { + return this->At(N_INDEX); + } + + CATLASS_HOST_DEVICE + Index const &d() const + { + return this->At(D_INDEX); + } + CATLASS_HOST_DEVICE + Index &d() + { + return this->At(D_INDEX); + } + + CATLASS_HOST_DEVICE + Index const &c1() const + { + return this->At(C1_INDEX); + } + CATLASS_HOST_DEVICE + Index &c1() + { + return this->At(C1_INDEX); + } + + CATLASS_HOST_DEVICE + Index const &hw() const + { + return this->At(HW_INDEX); + } + CATLASS_HOST_DEVICE + Index &hw() + { + return this->At(HW_INDEX); + } +}; + +struct Conv3dFracZ3dCoord : public Coord<2, uint32_t> { + using Index = uint32_t; + + using Base = Coord<2, Index>; + + static constexpr int KDC1KHKW_INDEX = 0; + static constexpr int N1_INDEX = 1; + + /// Default ctor + CATLASS_HOST_DEVICE + Conv3dFracZ3dCoord() {} + + CATLASS_HOST_DEVICE + Conv3dFracZ3dCoord(Index kdc1khkw, Index n1) : Base(MakeCoord(kdc1khkw, n1)) {} + + CATLASS_HOST_DEVICE + Index const &kdc1khkw() const + { + return this->At(KDC1KHKW_INDEX); + } + CATLASS_HOST_DEVICE + Index &kdc1khkw() + { + return this->At(KDC1KHKW_INDEX); + } + + CATLASS_HOST_DEVICE + Index const &n1() const + { + return this->At(N1_INDEX); + } + CATLASS_HOST_DEVICE + Index &n1() + { + return this->At(N1_INDEX); + } +}; +} // namespace Catlass + +#endif // CATLASS_CONV_COORD_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/coord.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/coord.hpp new file mode 100644 index 00000000..661e096d --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/coord.hpp @@ -0,0 +1,333 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
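MakeConvCoord above derives the output extents with the usual padded, dilated convolution formula, out = (in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1, applied per axis. A standalone check with hypothetical shapes:

#include <cassert>
#include <cstdint>

// The per-axis output-extent formula used by Conv3dParams::MakeConvCoord.
static uint32_t ConvOutDim(uint32_t in, uint32_t pad, uint32_t dilation, uint32_t kernel, uint32_t stride)
{
    return (in + pad * 2 - dilation * (kernel - 1) - 1) / stride + 1;
}

int main()
{
    // Hypothetical: Hi = 56, padTop = 1, dilationH = 1, Kh = 3, strideH = 2  ->  Ho = 28
    assert(ConvOutDim(56, 1, 1, 3, 2) == 28);
    // Hypothetical: Di = 16, padHead = 0, dilationD = 1, Kd = 1, strideD = 1  ->  Do = 16
    assert(ConvOutDim(16, 0, 1, 1, 1) == 16);
    return 0;
}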
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_COORD_HPP +#define CATLASS_COORD_HPP + +#include "catlass/catlass.hpp" + +namespace Catlass { + +/// Statically-sized array specifying Coords within a tensor +template +struct Coord { +public: + // Number of elements in Coord + static const int RANK = RANK_; + + // Index typen used to store elements + using Index = Index_; + + // Type used to represent linear offsets + using LongIndex = LongIndex_; + + // Default ctor initializes uniformly + CATLASS_HOST_DEVICE constexpr explicit Coord(Index value = Index(0)) + { + for (int i = 0; i < RANK; ++i) { + idx[i] = value; + } + } + + // Constructs from an array of integers + CATLASS_HOST_DEVICE constexpr Coord(Index const (&idx_)[RANK]) + { + for (int i = 0; i < RANK; ++i) { + idx[i] = idx_[i]; + } + } + + // Constructs from an array of integers + CATLASS_HOST_DEVICE + int Argmin() const + { + int i = 0; + for (int j = 1; j < RANK; ++j) { + if (idx[j] < idx[i]) { + i = j; + } + } + return i; + } + + // Returns the index of the dimension with greatest value + CATLASS_HOST_DEVICE + int Argmax() const + { + int i = 0; + for (int j = 1; j < RANK; ++j) { + if (idx[j] > idx[i]) { + i = j; + } + } + return i; + } + + // Returns true if Coord is non-zero + CATLASS_HOST_DEVICE + explicit operator bool() const + { + for (int i = 0; i < RANK; ++i) { + if (idx[i]) { + return true; + } + } + return false; + } + + // Return true if Coord is uniformly zero. + CATLASS_HOST_DEVICE + bool operator!() const + { + for (int i = 0; i < RANK; ++i) { + if (idx[i]) { + return false; + } + } + return true; + } + + // Element-wise addition + CATLASS_HOST_DEVICE + Coord operator+(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; ++i) { + c.idx[i] = idx[i] + b.idx[i]; + } + return c; + } + + // Add a scalar to each element + CATLASS_HOST_DEVICE + Coord operator+(const Index val) const + { + Coord c; + for (int i = 0; i < RANK; ++i) { + c.idx[i] = idx[i] + val; + } + return c; + } + + // Element-wise subtraction + CATLASS_HOST_DEVICE + Coord operator-(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; i++) { + c.idx[i] = idx[i] - b.idx[i]; + } + return c; + } + + // Subtract a scalar from each element + CATLASS_HOST_DEVICE + Coord operator-(Index const val) const + { + Coord c; + for (int i = 0; i < RANK; ++i) { + c.idx[i] = idx[i] - val; + } + return c; + } + + // Element-wise multiply + CATLASS_HOST_DEVICE + Coord operator*(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; i++) { + c.idx[i] = idx[i] * b.idx[i]; + } + return c; + } + + // Element-wise division + CATLASS_HOST_DEVICE + Coord operator/(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; i++) { + c.idx[i] = idx[i] / b.idx[i]; + } + return c; + } + + // Element-wise mod + CATLASS_HOST_DEVICE + Coord operator%(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; i++) { + c.idx[i] = idx[i] % b.idx[i]; + } + return c; + } + + // In-place addition + CATLASS_HOST_DEVICE + Coord &operator+=(Coord const &b) + { + for (int i = 0; i < RANK; ++i) { + idx[i] += b.idx[i]; + } + return *this; + } + + // In-place equal + CATLASS_HOST_DEVICE + bool operator==(Coord const &b) const + { + for (int i = 0; i < RANK; ++i) { + if (idx[i] != 
b.idx[i]) { + return false; + } + } + return true; + } + + // In-place equal + CATLASS_HOST_DEVICE + bool operator==(Index const val) const + { + for (int i = 0; i < RANK; ++i) { + if (idx[i] != val) { + return false; + } + } + return true; + } + + // Member access operator + CATLASS_HOST_DEVICE + Index &operator[](int dim) + { + return idx[dim]; + } + + // Member access operator + CATLASS_HOST_DEVICE + Index const &operator[](int dim) const + { + return idx[dim]; + } + + // Gets the index of a given Coord element + template + CATLASS_HOST_DEVICE Index &At() + { + return idx[DIM]; + } + + // Access via index; may limit unrolling potential + CATLASS_HOST_DEVICE + Index &At(int dim) + { + return idx[dim]; + } + + // Gets the index of a given Coord element + template + CATLASS_HOST_DEVICE Index const &At() const + { + return idx[DIM]; + } + + // Access via index; may limit unrolling potential + CATLASS_HOST_DEVICE + Index const &At(int dim) const + { + return idx[dim]; + } + + template + CATLASS_HOST_DEVICE auto GetCoordByAxis() const + { + Index idx_[sizeof...(Is)]{idx[Is]...}; + return Coord{idx_}; + } + + CATLASS_HOST_DEVICE + static Coord Min(Coord const &a, Coord const &b) + { + Coord res; + for (int i = 0; i < RANK; ++i) { + res[i] = a[i] < b[i] ? a[i] : b[i]; + } + return res; + } + +private: + // Indices + Index idx[RANK]; +}; + +// Helper to make a 1-element coordinate +template +CATLASS_HOST_DEVICE constexpr Coord<1, T> MakeCoord(T dim0) +{ + T values[1] = {dim0}; + return Coord<1, T>(values); +} + +/// Helper to make a 2-element coordinate +template +CATLASS_HOST_DEVICE constexpr Coord<2, T> MakeCoord(T dim0, T dim1) +{ + T values[2] = {dim0, dim1}; + return Coord<2, T>(values); +} + +/// Helper to make a 3-element coordinate +template +CATLASS_HOST_DEVICE constexpr Coord<3, T> MakeCoord(T dim0, T dim1, T dim2) +{ + T values[3] = {dim0, dim1, dim2}; + return Coord<3, T>(values); +} + +/// Helper to make a 4-element coordinate +template +CATLASS_HOST_DEVICE constexpr Coord<4, T> MakeCoord(T dim0, T dim1, T dim2, T dim3) +{ + T values[4] = {dim0, dim1, dim2, dim3}; + return Coord<4, T>(values); +} + +/// Helper to make a 5-element coordinate +template +CATLASS_HOST_DEVICE constexpr Coord<5, T> MakeCoord(T dim0, T dim1, T dim2, T dim3, T dim4) +{ + T values[5] = {dim0, dim1, dim2, dim3, dim4}; + return Coord<5, T>(values); +} + +/// Helper to make a 6-element coordinate +template +CATLASS_HOST_DEVICE constexpr Coord<6, T> MakeCoord(T dim0, T dim1, T dim2, T dim3, T dim4, T dim5) +{ + T values[6] = {dim0, dim1, dim2, dim3, dim4, dim5}; + return Coord<6, T>(values); +} + +/// Helper to make a 7-element coordinate +template +CATLASS_HOST_DEVICE constexpr Coord<7, T> MakeCoord(T dim0, T dim1, T dim2, T dim3, T dim4, T dim5, T dim6) +{ + T values[7] = {dim0, dim1, dim2, dim3, dim4, dim5, dim6}; + return Coord<7, T>(values); +} + +} // namespace Catlass + +#endif // CATLASS_COORD_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/debug.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/debug.hpp new file mode 100644 index 00000000..32fffb0a --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/debug.hpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
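Coord is the small fixed-rank index vector everything above is built on; Conv3d6HdCoord and the tile shapes are thin wrappers around it. A host-side sketch of its element-wise arithmetic, assuming a non-CCE build where CATLASS_HOST_DEVICE expands to nothing so this compiles as ordinary C++:

#include <cassert>
#include <cstdint>
#include "catlass/coord.hpp"  // the header added above

void CoordSketch()
{
    Catlass::Coord<4, uint32_t> a = Catlass::MakeCoord(1u, 2u, 3u, 4u);
    Catlass::Coord<4, uint32_t> b = Catlass::MakeCoord(4u, 3u, 2u, 1u);

    auto sum = a + b;                                  // element-wise: {5, 5, 5, 5}
    auto lo = Catlass::Coord<4, uint32_t>::Min(a, b);  // element-wise min: {1, 2, 2, 1}

    assert(sum == 5u);                 // operator==(Index) compares every element to the scalar
    assert(lo[0] == 1u && lo[3] == 1u);
    assert(b.Argmax() == 0);           // index of the largest element of {4, 3, 2, 1}
}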
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_DEBUG_HPP +#define CATLASS_DEBUG_HPP + +#pragma push_macro("inline") +#include +#include +#include +#pragma pop_macro("inline") + +#include +#include + +#define SINGLE_CORE_DUMPSIZE (1024 * 1024) +// 75 is from AscendC host stub +#define ALL_DUMPSIZE (75 * SINGLE_CORE_DUMPSIZE) + +using LogFuncType = std::function; +/** + * @brief Check acl api status code. + * @param status The return code of acl api. + * @param logFunc Log function, which receives a C-Style string. + * @return + */ +inline void aclCheck(aclError status, LogFuncType logFunc = [](const char *logStrPtr) { std::cerr << logStrPtr; }) +{ + if (status != ACL_SUCCESS) { + std::stringstream ss; + ss << "AclError: " << status; + logFunc(ss.str().c_str()); + } +} +/** + * @brief Check rt api status code. + * @param status The return code of rt api. + * @param logFunc Log function, which receives a C-Style string. + * @return + */ +inline void rtCheck(rtError_t status, LogFuncType logFunc = [](const char *logStrPtr) { std::cerr << logStrPtr; }) +{ + if (status != RT_ERROR_NONE) { + std::stringstream ss; + ss << "RtError: " << status; + logFunc(ss.str().c_str()); + } +} + +namespace Adx { +void AdumpPrintWorkSpace(const void *dumpBufferAddr, const size_t dumpBufferSize, aclrtStream stream, + const char *opType); +} // namespace Adx + +#endif // CATLASS_DEBUG_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/alignment.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/alignment.hpp new file mode 100644 index 00000000..d6177c5f --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/alignment.hpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
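aclCheck and rtCheck above only format the failing status code and forward it, so the useful knob is supplying your own sink. A host-side sketch assuming an ACL build environment, routing failures somewhere other than the default std::cerr; the buffer name and log prefix are made up for illustration:

#include <cstddef>
#include <iostream>
#include "catlass/debug.hpp"  // the header added above; pulls in the ACL runtime headers

// Allocates a device buffer and reports failures through a custom logger.
inline void AllocDeviceBuffer(void **devPtr, size_t bytes)
{
    aclCheck(aclrtMalloc(devPtr, bytes, ACL_MEM_MALLOC_HUGE_FIRST),
             [](const char *msg) { std::cerr << "[deep_ep conv3d] " << msg << std::endl; });
}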
+ */ + +#ifndef CATLASS_ALIGNMENT_HPP +#define CATLASS_ALIGNMENT_HPP + +#include "catlass/detail/macros.hpp" +#include "tla/numeric/integral_constant.hpp" + +template +CATLASS_HOST_DEVICE constexpr T RoundUp(const T &val) +{ + static_assert(ALIGN != 0, "ALIGN must not be 0"); + return (val + ALIGN - 1) / ALIGN * ALIGN; +} + +template +CATLASS_HOST_DEVICE constexpr auto RoundUp(T const &val, U const &align) +{ + if constexpr (tla::is_static::value && tla::is_static::value) { // Int, Int + constexpr uint32_t res = (T::value + U::value - 1) / U::value * U::value; + return tla::Int{}; + } else if constexpr (tla::is_static::value) { // Int, int + return (T::value + align - 1) / align * align; + } else if constexpr (tla::is_static::value) { // int, Int + return (val + U::value - 1) / U::value * U::value; + } else { // int, int + return (val + align - 1) / align * align; + } +} + +template +CATLASS_HOST_DEVICE constexpr T RoundDown(const T val) +{ + static_assert(ALIGN != 0, "ALIGN must not be 0"); + return val / ALIGN * ALIGN; +} + +template +CATLASS_HOST_DEVICE constexpr auto RoundDown(T const &val, U const &align) +{ + if constexpr (tla::is_static::value && tla::is_static::value) { // Int, Int + constexpr uint32_t res = T::value / U::value * U::value; + return tla::Int{}; + } else if constexpr (tla::is_static::value) { // Int, int + return T::value / align * align; + } else if constexpr (tla::is_static::value) { // int, Int + return val / U::value * U::value; + } else { // int, int + return val / align * align; + } +} + +template +CATLASS_HOST_DEVICE constexpr T CeilDiv(const T dividend) +{ + static_assert(DIVISOR != 0, "DIVISOR must not be 0"); + return (dividend + DIVISOR - 1) / DIVISOR; +} + +template +CATLASS_HOST_DEVICE constexpr auto CeilDiv(T const ÷nd, U const &divisor) +{ + if constexpr (tla::is_static::value && tla::is_static::value) { // Int, Int + constexpr uint32_t res = (T::value + U::value - 1) / U::value; + return tla::Int{}; + } else if constexpr (tla::is_static::value) { // Int, int + return (T::value + divisor - 1) / divisor; + } else if constexpr (tla::is_static::value) { // int, Int + return (dividend + U::value - 1) / U::value; + } else { // int, int + return (dividend + divisor - 1) / divisor; + } +} + +#endif // CATLASS_ALIGNMENT_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/callback.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/callback.hpp new file mode 100644 index 00000000..83d80bc3 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/callback.hpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_DETAIL_CALLBACK_HPP +#define CATLASS_DETAIL_CALLBACK_HPP + +#include "catlass/detail/macros.hpp" + +/// @brief Callback is an alternative to std::function, providing a general carrier +/// of callable structure with no parameters and no return value. 
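Before the Callback carrier below, a compile-time sketch of the alignment helpers just defined; they back the tiling arithmetic earlier in the diff (loop counts, tail sizes, 16-element alignment of m and n). The run-time (int, int) overloads are used here because they are still constexpr; the tla::Int forms additionally keep the result a static value:

#include "catlass/detail/alignment.hpp"  // the header added above

// 70 elements tiled by 16: five tiles, padded extent 80, truncated extent 64.
static_assert(CeilDiv(70u, 16u) == 5u, "number of 16-wide tiles covering 70 elements");
static_assert(RoundUp(70u, 16u) == 80u, "70 padded up to the next multiple of 16");
static_assert(RoundDown(70u, 16u) == 64u, "70 truncated down to a multiple of 16");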
Compared with function pointers +/// of type void (*)(), Callback can carry lambda expressions with captures, and does not need to +/// pay attention to the captured content. It should be noted that Callback itself does not store +/// the callable structure it carries like std::function, so it is necessary to ensure +/// that it is used within the life cycle of the callable structure. +struct Callback { + void const *func{nullptr}; + void (*caller)(void const *){nullptr}; + + Callback() = default; + + CATLASS_DEVICE + void operator()() const + { + if (func) { + caller(func); + } + } + + CATLASS_DEVICE + operator bool() const + { + return func != nullptr; + } +}; + +template +CATLASS_DEVICE void FuncWrapper(void const *func) +{ + (*static_cast(func))(); +} + +// Use this to make a callback +template +CATLASS_DEVICE Callback MakeCallback(Func *func) +{ + Callback callback; + callback.func = func; + callback.caller = &FuncWrapper; + return callback; +} + +#endif // CATLASS_DETAIL_CALLBACK_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/dependent_false.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/dependent_false.hpp new file mode 100644 index 00000000..f7e3b083 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/dependent_false.hpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_DETAIL_DEPENDENT_FALSE_HPP +#define CATLASS_DETAIL_DEPENDENT_FALSE_HPP + +template +constexpr bool DEPENDENT_BOOL_VALUE = VALUE; + +template +constexpr bool DEPENDENT_FALSE = DEPENDENT_BOOL_VALUE; + +#endif // CATLASS_DETAIL_DEPENDENT_FALSE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/macros.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/macros.hpp new file mode 100644 index 00000000..8a903b79 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/macros.hpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
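// Usage sketches for the two utilities above. MakeCallback only stores a pointer to the callable,
// so the lambda must outlive the Callback (as the @brief warns). DEPENDENT_FALSE delays a
// static_assert until a template is actually instantiated, the same trick block_epilogue.hpp uses
// for its unspecialized primary template. Device qualifiers are ignored here so the idea reads as
// ordinary C++; treat this as a sketch, not part of the patch.
inline void CallbackExample()
{
    int counter = 0;
    auto bump = [&counter]() { counter += 1; };  // must stay alive while cb is used
    Callback cb = MakeCallback(&bump);
    if (cb) {
        cb();  // invokes bump(); counter is now 1
    }
}

template <class T>
struct MustBeSpecialized {
    // Fires only if someone instantiates the primary template instead of a specialization.
    static_assert(DEPENDENT_FALSE<T>, "provide a specialization of MustBeSpecialized for T");
};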
+ */ + +#ifndef CATLASS_DETAIL_MACROS_HPP +#define CATLASS_DETAIL_MACROS_HPP + +#if defined(__CCE__) +#include +#endif + +#define CATLASS_DEVICE __forceinline__ __aicore__ +#ifdef __CCE__ +#define CATLASS_HOST_DEVICE __forceinline__[host, aicore] +#else +#define CATLASS_HOST_DEVICE +#endif +#define CATLASS_GLOBAL __global__ __aicore__ + +#endif // CATLASS_DETAIL_MACROS_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/tag_to_layout.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/tag_to_layout.hpp new file mode 100644 index 00000000..30e08923 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/detail/tag_to_layout.hpp @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_DETAIL_TAG_TO_LAYOUT_HPP +#define CATLASS_DETAIL_TAG_TO_LAYOUT_HPP + +#include "catlass/layout/layout.hpp" +#include "tla/layout.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace Catlass::detail { +//////////////////////////////////////////////////////////////////////////////////////////////////// +// For each Catlass::layout, provides its corresponding tla layout types +template +struct TagToLayout { + using type = LayoutTag; +}; + +template +struct TagToLayout { + using type = tla::Layout, tla::Stride>>; +}; + +template +struct TagToLayout { + using type = tla::Layout, tla::Stride, int64_t>>; +}; + +template +struct TagToLayout { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + using type = tla::Layout< + tla::Shape, uint32_t>, tla::Shape, uint32_t>>, + tla::Stride, tla::Int>, + tla::Stride, int64_t>>>; +}; + +template +struct TagToLayout { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + using type = tla::Layout< + tla::Shape, uint32_t>, tla::Shape, uint32_t>>, + tla::Stride, int64_t>, + tla::Stride, tla::Int>>>; +}; + +template +struct TagToLayout { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + using type = tla::Layout< + tla::Shape, uint32_t>, tla::Shape, uint32_t>>, + tla::Stride, int64_t>, + tla::Stride, tla::Int>>>; +}; + +// Convenience aliases +template +using TagToLayout_t = typename TagToLayout::type; + +constexpr uint32_t ELE_NUM_PER_FRACTAL_L0C = 256; +using LayoutL0C = tla::Layout< + tla::Shape, uint32_t>, tla::Shape, uint32_t>>, + tla::Stride, tla::Int>, + tla::Stride, int64_t>>>; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Catlass::detail + +#endif // CATLASS_DETAIL_TAG_TO_LAYOUT_HPP diff --git 
a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue.hpp new file mode 100644 index 00000000..1f4ccc54 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue.hpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP +#define CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP + +#include "catlass/catlass.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue +{ + static_assert(DEPENDENT_FALSE, "Could not find an epilogue specialization"); +}; + +} // namespace Catlass::Epilogue::Block + +#include "catlass/epilogue/block/block_epilogue_elemwise_no_source.hpp" +#include "catlass/epilogue/block/block_epilogue_elemwise_one_source.hpp" +#include "catlass/epilogue/block/block_epilogue_fa_softmax.hpp" +#include "catlass/epilogue/block/block_epilogue_fa_rescale_o.hpp" +#include "catlass/epilogue/block/block_epilogue_mla_softmax.hpp" +#include "catlass/epilogue/block/block_epilogue_mla_rescale_o.hpp" +#include "catlass/epilogue/block/block_epilogue_mla_fd_rescale_o.hpp" +#include "catlass/epilogue/block/block_epilogue_per_token_dequant.hpp" +#include "catlass/epilogue/block/block_epilogue_gemm.hpp" +#include "catlass/epilogue/block/block_epilogue_gemv.hpp" +#include "catlass/epilogue/block/block_epilogue_mla_tp1_softmax.hpp" +#include "catlass/epilogue/block/block_epilogue_mla_tp1_rescale_o.hpp" +#include "catlass/epilogue/block/block_epilogue_online_softmax_no_mask.hpp" +#include "catlass/epilogue/block/block_epilogue_rescale_o_no_split_row.hpp" +#endif // CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_elemwise_no_source.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_elemwise_no_source.hpp new file mode 100644 index 00000000..78bdcec3 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_elemwise_no_source.hpp @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
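// How the primary template above is meant to be used: each concrete epilogue is a partial
// specialization selected by its dispatch policy, and kernels pick an epilogue purely by naming
// that policy. The policy name and template parameters below are illustrative placeholders, not
// types added by this patch.
namespace Catlass::Epilogue::Block {

struct EpilogueMyPolicy {};  // hypothetical dispatch-policy tag

template <class CType, class DType>
class BlockEpilogue<EpilogueMyPolicy, CType, DType> {
public:
    using DispatchPolicy = EpilogueMyPolicy;
    // tile copies, UB buffers and an operator() doing the per-block work would live here
};

}  // namespace Catlass::Epilogue::Block
// A kernel would then instantiate:
//   using Epilogue = Catlass::Epilogue::Block::BlockEpilogue<EpilogueMyPolicy, CType, DType>;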
+ */ + +#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_ELEMWISE_NO_SOURCE_HPP +#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_ELEMWISE_NO_SOURCE_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/layout/layout.hpp" + +namespace Catlass::Epilogue::Block { +// 部分特化:当DispatchPolicy为EpilogueAtlasA2ElemWiseNoSource时的特化版本 +template +class BlockEpilogue +{ +public: + // Type aliases + using DispatchPolicy = EpilogueAtlasA2ElemWiseNoSource; + using ArchTag = typename DispatchPolicy::ArchTag; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementX = typename CType_::Element; // X是fp32的计算结果,无GM + using LayoutX = typename CType_::Layout; + + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + using TileElemWiseEpilogue = TileElemWiseEpilogue_; + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + static constexpr uint32_t COMPUTE_LENGTH = TileElemWiseEpilogue::COMPUTE_LENGTH; + static constexpr uint32_t OPERANDS_NUM = DispatchPolicy::OPERANDS_NUM; + + using ElementCompute = ElementC; + using ElementOut = ElementD; + + using LayoutComputeInUb = layout::RowMajor; + + // Check the element type of C and D + static_assert(std::is_same_v, "Element type of C must be float"); + // Check the layout type of C and D + static_assert(std::is_same_v && std::is_same_v, + "Layout type of C, D must be RowMajor"); + + // Check if ArchTag is matched + static_assert(std::is_same_v, "Tile epilogue's ArchTag mismatch"); + // Check if compute length is valid + static_assert(COMPUTE_LENGTH * (OPERANDS_NUM * sizeof(ElementC) + sizeof(ElementD)) <= ArchTag::UB_SIZE, + "UB out of bounds"); + + // Epilogue params definition + struct Params { + GM_ADDR ptrC; + LayoutC layoutC; + GM_ADDR ptrD; + LayoutD layoutD; + + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GM_ADDR ptrC_, LayoutD const &layoutC_, GM_ADDR ptrD_, LayoutD const &layoutD_) + : ptrC(ptrC_), layoutC(layoutC_), ptrD(ptrD_), layoutD(layoutD_) + {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, Params const ¶ms) : params(params) + { + ubC = resource.ubBuf.template GetBufferByByte(0); + ubX = resource.ubBuf.template GetBufferByByte(COMPUTE_LENGTH * sizeof(ElementC)); + ubD = resource.ubBuf.template GetBufferByByte(COMPUTE_LENGTH * sizeof(ElementC) * 2); + + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID0); + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + + CATLASS_DEVICE + void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, + LayoutD const &layoutBlockC) + { + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + // Calculate the offset and the shape of the current subblock + MatrixCoord subblockShape{CeilDiv(actualBlockShape.row(), static_cast(AscendC::GetSubBlockNum())), + actualBlockShape.column()}; + + MatrixCoord subblockCoord{AscendC::GetSubBlockIdx(), 0}; + MatrixCoord actualSubblockShape = + MatrixCoord::Min(subblockShape, 
actualBlockShape - subblockCoord * subblockShape); + MatrixCoord subblockOffset = subblockCoord * subblockShape; + + // Get the data and layout of C + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrC)); + auto gmSubblockC = gmBlockC[params.layoutC.GetOffset(subblockOffset)]; + auto layoutSubblockC = params.layoutC.GetTileLayout(actualSubblockShape); + + // Get the data and layout of D + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD *>(params.ptrD)); + auto gmSubblockD = gmD[params.layoutD.GetOffset(blockOffset + subblockOffset)]; + auto layoutSubblockD = params.layoutD.GetTileLayout(actualSubblockShape); + + // Get the layout on UB + auto layoutComputeInUb = LayoutComputeInUb::template MakeLayoutInUb(actualSubblockShape); + auto layoutComputeOutUb = LayoutComputeInUb::template MakeLayoutInUb(actualSubblockShape); + // Copy the data of C + AscendC::WaitFlag(EVENT_ID0); + copyGmToUbC(ubC, gmSubblockC, layoutComputeInUb, layoutSubblockC); + AscendC::SetFlag(EVENT_ID0); + + // Perform epilogue calculation + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + tileEpilogue(ubX, ubC); + AscendC::Cast(ubD, ubX, AscendC::RoundMode::CAST_RINT, COMPUTE_LENGTH); + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID0); + + // Copy the data of D + AscendC::WaitFlag(EVENT_ID0); + copyUbToGmD(gmSubblockD, ubD, layoutSubblockD, layoutComputeOutUb); + AscendC::SetFlag(EVENT_ID0); + } + +private: + Params params; + + AscendC::LocalTensor ubC; + AscendC::LocalTensor ubX; + AscendC::LocalTensor ubD; + + TileElemWiseEpilogue tileEpilogue; + CopyGmToUbC copyGmToUbC; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_ELEMWISE_NO_SOURCE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_elemwise_one_source.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_elemwise_one_source.hpp new file mode 100644 index 00000000..fdc745f0 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_elemwise_one_source.hpp @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
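// Scalar walk-through of the sub-block split used by operator() above, assuming the usual two
// vector sub-blocks per AI core; the row count is only an example.
#include <algorithm>
#include <cstdint>

inline void SubblockSplitExample()
{
    uint32_t actualRows = 93, subBlockNum = 2;
    uint32_t subblockRows = (actualRows + subBlockNum - 1) / subBlockNum;  // CeilDiv -> 47
    // Sub-block 0 handles rows [0, 47); sub-block 1 gets min(47, 93 - 47) = 46 rows, the Min() clamp.
    uint32_t rowsForSubBlock1 = std::min(subblockRows, actualRows - 1 * subblockRows);
    (void)rowsForSubBlock1;
}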
+ */ + +#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_ELEMWISE_ONE_SOURCE_HPP +#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_ELEMWISE_ONE_SOURCE_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/layout/layout.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue +{ +public: + // Type aliases + using DispatchPolicy = EpilogueAtlasA2ElemWiseOneSource; + using ArchTag = typename DispatchPolicy::ArchTag; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementX = typename XType_::Element; + using LayoutX = typename XType_::Layout; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + using TileElemWiseEpilogue = TileElemWiseEpilogue_; + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbX = typename TileCopy_::CopyGmToUbX; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + static constexpr uint32_t COMPUTE_LENGTH = TileElemWiseEpilogue::COMPUTE_LENGTH; + static constexpr uint32_t OPERANDS_NUM = DispatchPolicy::OPERANDS_NUM; + + // Check the element type of C, X and D + static_assert(std::is_same_v && std::is_same_v, + "Element type of C, X and D must be the same"); + using ElementCompute = ElementD; + + // Check the layout type of C, X and D + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v, + "Layout type of C, X and D must be RowMajor"); + using LayoutComputeInUb = layout::RowMajor; + + // Check if ArchTag is matched + static_assert(std::is_same_v, "Tile epilogue's ArchTag mismatch"); + // Check if compute length is valid + static_assert(COMPUTE_LENGTH * OPERANDS_NUM * sizeof(ElementCompute) <= ArchTag::UB_SIZE, "UB out of bounds"); + + // Epilogue params definition + struct Params { + GM_ADDR ptrX; + LayoutX layoutX; + GM_ADDR ptrD; + LayoutD layoutD; + + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GM_ADDR ptrX_, LayoutX const &layoutX_, GM_ADDR ptrD_, LayoutD const &layoutD_) + : ptrX(ptrX_), layoutX(layoutX_), ptrD(ptrD_), layoutD(layoutD_) + {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, Params const ¶ms) : params(params) + { + ubC = resource.ubBuf.template GetBufferByByte(0); + ubX = resource.ubBuf.template GetBufferByByte(COMPUTE_LENGTH * sizeof(ElementC)); + ubD = resource.ubBuf.template GetBufferByByte(COMPUTE_LENGTH * sizeof(ElementC) + + COMPUTE_LENGTH * sizeof(ElementX)); + + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID0); + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + + CATLASS_DEVICE + void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, + LayoutX const &layoutBlockC) + { + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + // Calculate the offset and the shape of the current subblock + MatrixCoord subblockShape{CeilDiv(actualBlockShape.row(), static_cast(AscendC::GetSubBlockNum())), + actualBlockShape.column()}; + MatrixCoord subblockCoord{AscendC::GetSubBlockIdx(), 0}; + MatrixCoord actualSubblockShape = + 
MatrixCoord::Min(subblockShape, actualBlockShape - subblockCoord * subblockShape); + MatrixCoord subblockOffset = subblockCoord * subblockShape; + + // Get the data and layout of C + auto gmSubblockC = gmBlockC[layoutBlockC.GetOffset(subblockOffset)]; + auto layoutSubblockC = layoutBlockC.GetTileLayout(actualSubblockShape); + + // Get the data and layout of X + AscendC::GlobalTensor gmX; + gmX.SetGlobalBuffer(reinterpret_cast<__gm__ ElementX *>(params.ptrX)); + auto gmSubblockX = gmX[params.layoutX.GetOffset(blockOffset + subblockOffset)]; + auto layoutSubblockX = params.layoutX.GetTileLayout(actualSubblockShape); + + // Get the data and layout of D + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD *>(params.ptrD)); + auto gmSubblockD = gmD[params.layoutD.GetOffset(blockOffset + subblockOffset)]; + auto layoutSubblockD = params.layoutD.GetTileLayout(actualSubblockShape); + + // Get the layout on UB + auto layoutComputeInUb = LayoutComputeInUb::template MakeLayoutInUb(actualSubblockShape); + + // Copy the data of C and X + AscendC::WaitFlag(EVENT_ID0); + copyGmToUbC(ubC, gmSubblockC, layoutComputeInUb, layoutSubblockC); + copyGmToUbX(ubX, gmSubblockX, layoutComputeInUb, layoutSubblockX); + AscendC::SetFlag(EVENT_ID0); + + // Perform epilogue calculation + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + tileEpilogue(ubD, ubC, ubX); + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID0); + + // Copy the data of D + AscendC::WaitFlag(EVENT_ID0); + copyUbToGmD(gmSubblockD, ubD, layoutSubblockD, layoutComputeInUb); + AscendC::SetFlag(EVENT_ID0); + } + +private: + Params params; + + AscendC::LocalTensor ubC; + AscendC::LocalTensor ubX; + AscendC::LocalTensor ubD; + + TileElemWiseEpilogue tileEpilogue; + CopyGmToUbC copyGmToUbC; + CopyGmToUbX copyGmToUbX; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_ELEMWISE_ONE_SOURCE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_fa_rescale_o.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_fa_rescale_o.hpp new file mode 100644 index 00000000..b58690a9 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_fa_rescale_o.hpp @@ -0,0 +1,279 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
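// Plain reference for what the one-source element-wise epilogue above computes per tile:
// D = op(C, X), written back to global memory. The op is whatever TileElemWiseEpilogue_
// implements; an element-wise add is shown here as the typical case (an assumption, not a fact
// stated by the diff).
#include <cstddef>

inline void ElemwiseOneSourceRef(const float *C, const float *X, float *D, std::size_t rows, std::size_t cols)
{
    for (std::size_t i = 0; i < rows; ++i) {
        for (std::size_t j = 0; j < cols; ++j) {
            const std::size_t idx = i * cols + j;  // RowMajor, as the static_asserts require
            D[idx] = C[idx] + X[idx];              // tileEpilogue(ubD, ubC, ubX) with an add-style op
        }
    }
}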
+ */ + +#ifndef CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_FA_RESCALE_O_HPP +#define CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_FA_RESCALE_O_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue +{ +public: + // Type aliases + using DispatchPolicy = EpilogueAtlasA2FARescaleO; + using ArchTag = typename DispatchPolicy::ArchTag; + using ElementOutput = typename OutputType_::Element; + using ElementInput = typename InputType_::Element; + + using LayoutOutput = typename OutputType_::Layout; + using LayoutInput = typename InputType_::Layout; + + using CopyGmToUbInput = Tile::CopyGm2Ub; + using CopyUbToGmOutput = Tile::CopyUb2Gm; + + static constexpr uint32_t HALF_ELENUM_PER_BLK = 16; + static constexpr uint32_t FLOAT_ELENUM_PER_BLK = 8; + static constexpr uint32_t HALF_ELENUM_PER_VECCALC = 128; + static constexpr uint32_t FLOAT_ELENUM_PER_VECCALC = 64; + static constexpr uint32_t UB_TILE_SIZE = 16384; // 64 * 128 * 2B + static constexpr uint32_t UB_LINE_SIZE = 512; // 128 * 2 * 2B + static constexpr uint32_t HALF_ELENUM_PER_LINE = 256; // 128 * 2 + static constexpr uint32_t FLOAT_ELENUM_PER_LINE = 128; // 128 + static constexpr uint32_t MULTIPLIER = 2; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource) + { + constexpr uint32_t LO_UB_TENSOR_OFFSET = 5 * UB_TILE_SIZE; + constexpr uint32_t DM_UB_TENSOR_OFFSET = 7 * UB_TILE_SIZE + 4 * UB_LINE_SIZE; + constexpr uint32_t LL_UB_TENSOR_OFFSET = 7 * UB_TILE_SIZE + 6 * UB_LINE_SIZE; + constexpr uint32_t GL_UB_TENSOR_OFFSET = 7 * UB_TILE_SIZE + 10 * UB_LINE_SIZE; + constexpr uint32_t TV_UB_TENSOR_OFFSET = 7 * UB_TILE_SIZE + 11 * UB_LINE_SIZE; + constexpr uint32_t GO_UB_TENSOR_OFFSET = 8 * UB_TILE_SIZE; + + loUbTensor = resource.ubBuf.template GetBufferByByte(LO_UB_TENSOR_OFFSET); + dmUbTensor = resource.ubBuf.template GetBufferByByte(DM_UB_TENSOR_OFFSET); + llUbTensor = resource.ubBuf.template GetBufferByByte(LL_UB_TENSOR_OFFSET); + glUbTensor = resource.ubBuf.template GetBufferByByte(GL_UB_TENSOR_OFFSET); + tvUbTensor = resource.ubBuf.template GetBufferByByte(TV_UB_TENSOR_OFFSET); + goUbTensor = resource.ubBuf.template GetBufferByByte(GO_UB_TENSOR_OFFSET); + + AscendC::SetFlag(EVENT_ID1); + AscendC::SetFlag(EVENT_ID0); + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + AscendC::WaitFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID0); + } + + CATLASS_DEVICE + void SetMask(int32_t len) + { + const int32_t MAX_MASK_LEN = 128; + const int32_t HALF_MASK_LEN = 64; + if (len >= MAX_MASK_LEN) { + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + return; + } + int32_t highMask = len - HALF_MASK_LEN > 0 ? len - HALF_MASK_LEN : 0; + int32_t lowMask = len - HALF_MASK_LEN >= 0 ? 
HALF_MASK_LEN : len; + if (len < HALF_MASK_LEN) { + AscendC::SetVectorMask(0x0, ((uint64_t)1 << lowMask) - 1); + } else { + AscendC::SetVectorMask(((uint64_t)1 << highMask) - 1, 0xffffffffffffffff); + } + } + + CATLASS_DEVICE + void subCoreCompute(AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gInput, + const LayoutOutput &layoutOutput, const LayoutInput &layoutInput, uint32_t nIdx, + uint32_t isLast) + { + uint32_t subM = layoutInput.shape(0); + uint32_t k = layoutInput.shape(1); + uint32_t kRound = layoutInput.stride(0); + uint32_t strideQO = layoutOutput.stride(0); + uint32_t subMAligned128 = (subM + HALF_ELENUM_PER_VECCALC - 1) / HALF_ELENUM_PER_VECCALC; + uint32_t subMAligned64 = (subM + FLOAT_ELENUM_PER_VECCALC - 1) / FLOAT_ELENUM_PER_VECCALC; + uint32_t subMRound = (subM + HALF_ELENUM_PER_BLK - 1) / HALF_ELENUM_PER_BLK * HALF_ELENUM_PER_BLK; + + // Get the layout on UB + LayoutInput layoutInUb(subM, k, kRound); + + if (subM > 0) { + AscendC::WaitFlag(EVENT_ID1); + // Copy O + copyGmToUbInput(loUbTensor, gInput, layoutInUb, layoutInput); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + // 更新 L 和 O + if (nIdx != 0) { + // dm32 = castfp16to32(dm) + AscendC::Cast(tvUbTensor, dmUbTensor[nIdx % MULTIPLIER * HALF_ELENUM_PER_LINE], + AscendC::RoundMode::CAST_NONE, (uint64_t)0, subMAligned64, + AscendC::UnaryRepeatParams(1, 1, 8, 4)); + AscendC::PipeBarrier(); + // dm32_block = brcb(dm32) + AscendC::Brcb(tvUbTensor.ReinterpretCast()[HALF_ELENUM_PER_VECCALC], + tvUbTensor.ReinterpretCast(), subMRound / FLOAT_ELENUM_PER_BLK, + AscendC::BrcbRepeatParams(1, 8)); + AscendC::PipeBarrier(); + // dm32 = exp(dm32) + AscendC::Exp(tvUbTensor, tvUbTensor, (uint64_t)0, subMAligned64, + AscendC::UnaryRepeatParams(1, 1, 8, 8)); + AscendC::PipeBarrier(); + // gl = dm * gl + AscendC::Mul(glUbTensor, tvUbTensor, glUbTensor, (uint64_t)0, subMAligned64, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + // gl = ll + gl + AscendC::Add(glUbTensor, glUbTensor, + llUbTensor[nIdx % MULTIPLIER * FLOAT_ELENUM_PER_LINE], (uint64_t)0, + subMAligned64, AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + // dm32_block = exp(dm32_block) + AscendC::Exp( + tvUbTensor[HALF_ELENUM_PER_VECCALC], tvUbTensor[HALF_ELENUM_PER_VECCALC], (uint64_t)0, + (subM * FLOAT_ELENUM_PER_BLK + FLOAT_ELENUM_PER_VECCALC - 1) / FLOAT_ELENUM_PER_VECCALC, + AscendC::UnaryRepeatParams(1, 1, 8, 8)); + AscendC::PipeBarrier(); + if (goFlag == 1) { + AscendC::WaitFlag(EVENT_ID0); + goFlag = 0; + } + // go = go * dm32_block + for (uint32_t vmulIdx = 0; vmulIdx < k / FLOAT_ELENUM_PER_VECCALC; vmulIdx++) { + AscendC::Mul(goUbTensor[vmulIdx * FLOAT_ELENUM_PER_VECCALC], + goUbTensor[vmulIdx * FLOAT_ELENUM_PER_VECCALC], + tvUbTensor[HALF_ELENUM_PER_VECCALC], (uint64_t)0, subM, + AscendC::BinaryRepeatParams(1, 1, 0, kRound / FLOAT_ELENUM_PER_BLK, + kRound / FLOAT_ELENUM_PER_BLK, 1)); + AscendC::PipeBarrier(); + } + if (k % FLOAT_ELENUM_PER_VECCALC > 0) { + SetMask(k % FLOAT_ELENUM_PER_VECCALC); + AscendC::Mul(goUbTensor[k / FLOAT_ELENUM_PER_VECCALC * FLOAT_ELENUM_PER_VECCALC], + goUbTensor[k / FLOAT_ELENUM_PER_VECCALC * FLOAT_ELENUM_PER_VECCALC], + tvUbTensor[HALF_ELENUM_PER_VECCALC], (uint64_t)0, subM, + AscendC::BinaryRepeatParams(1, 1, 0, kRound / FLOAT_ELENUM_PER_BLK, + kRound / FLOAT_ELENUM_PER_BLK, 1)); + AscendC::PipeBarrier(); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + // go = lo + go + AscendC::Add(goUbTensor, goUbTensor, loUbTensor, (uint64_t)0, + (subM * kRound 
+ FLOAT_ELENUM_PER_VECCALC - 1) / FLOAT_ELENUM_PER_VECCALC, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + } else { + // gl = ll + AscendC::DataCopy(glUbTensor, llUbTensor[nIdx % MULTIPLIER * FLOAT_ELENUM_PER_LINE], + AscendC::DataCopyParams(1, subMRound / FLOAT_ELENUM_PER_BLK, 0, 0)); + AscendC::PipeBarrier(); + if (goFlag == 1) { + AscendC::WaitFlag(EVENT_ID0); + goFlag = 0; + } + AscendC::DataCopy(goUbTensor, loUbTensor, + AscendC::DataCopyParams(1, subM * kRound / FLOAT_ELENUM_PER_BLK, 0, 0)); + AscendC::PipeBarrier(); + } + AscendC::SetFlag(EVENT_ID1); + if (isLast) { + // gl = castfp32to16(gl) + AscendC::Cast(glUbTensor.ReinterpretCast(), glUbTensor, + AscendC::RoundMode::CAST_NONE, (uint64_t)0, subMAligned64, + AscendC::UnaryRepeatParams(1, 1, 4, 8)); + AscendC::PipeBarrier(); + // go = castfp32to16(go) + AscendC::Cast( + goUbTensor.ReinterpretCast(), goUbTensor, AscendC::RoundMode::CAST_NONE, (uint64_t)0, + (subM * kRound + FLOAT_ELENUM_PER_VECCALC - 1) / FLOAT_ELENUM_PER_VECCALC, + AscendC::UnaryRepeatParams(1, 1, 4, 8)); + AscendC::PipeBarrier(); + // gl_block = brcb(gl) + AscendC::Brcb(tvUbTensor.ReinterpretCast(), glUbTensor.ReinterpretCast(), + subMRound / FLOAT_ELENUM_PER_BLK, AscendC::BrcbRepeatParams(1, 8)); + AscendC::PipeBarrier(); + // go = go / gl_block + for (uint32_t vdivIdx = 0; vdivIdx < k / HALF_ELENUM_PER_VECCALC; vdivIdx++) { + AscendC::Div(goUbTensor.ReinterpretCast()[vdivIdx * HALF_ELENUM_PER_VECCALC], + goUbTensor.ReinterpretCast()[vdivIdx * HALF_ELENUM_PER_VECCALC], + tvUbTensor.ReinterpretCast(), (uint64_t)0, subM, + AscendC::BinaryRepeatParams(1, 1, 0, kRound / HALF_ELENUM_PER_BLK, + kRound / HALF_ELENUM_PER_BLK, 1)); + } + if (k % HALF_ELENUM_PER_VECCALC > 0) { + SetMask(k % HALF_ELENUM_PER_VECCALC); + AscendC::Div( + goUbTensor.ReinterpretCast()[k / HALF_ELENUM_PER_VECCALC * HALF_ELENUM_PER_VECCALC], + goUbTensor.ReinterpretCast()[k / HALF_ELENUM_PER_VECCALC * HALF_ELENUM_PER_VECCALC], + tvUbTensor.ReinterpretCast(), (uint64_t)0, subM, + AscendC::BinaryRepeatParams(1, 1, 0, kRound / HALF_ELENUM_PER_BLK, kRound / HALF_ELENUM_PER_BLK, + 1)); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + // copy O to GM + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + copyUbToGmOutput(gOutput, goUbTensor.ReinterpretCast(), layoutOutput, layoutInUb); + if (goFlag == 0) { + AscendC::SetFlag(EVENT_ID0); + goFlag = 1; + } + } + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gInput, + const LayoutOutput &layoutOutput, const LayoutInput &layoutInput, GemmCoord actualBlockShape, + uint32_t nIdx, uint32_t isLast) + { + uint32_t mActual = actualBlockShape.m(); + uint32_t nActual = actualBlockShape.n(); + + uint32_t subBlockIdx = AscendC::GetSubBlockIdx(); + uint32_t subBlockNum = AscendC::GetSubBlockNum(); + + uint32_t mActualPerSubBlock = CeilDiv(mActual, subBlockNum); + uint32_t mActualThisSubBlock = (subBlockIdx == 0) ? 
mActualPerSubBlock : (mActual - mActualPerSubBlock); + uint32_t mOffset = subBlockIdx * mActualPerSubBlock; + uint32_t nOffset = 0; + + int64_t offsetOutput = layoutOutput.GetOffset(MatrixCoord(mOffset, nOffset)); + auto gOutputThisSubBlock = gOutput[offsetOutput]; + auto layoutOutputThisSubBlock = layoutOutput.GetTileLayout(MatrixCoord(mActualThisSubBlock, nActual)); + + int64_t offsetInput = layoutInput.GetOffset(MatrixCoord(mOffset, nOffset)); + auto gInputThisSubBlock = gInput[offsetInput]; + auto layoutInputThisSubBlock = layoutInput.GetTileLayout(MatrixCoord(mActualThisSubBlock, nActual)); + + subCoreCompute(gOutputThisSubBlock, gInputThisSubBlock, layoutOutputThisSubBlock, layoutInputThisSubBlock, nIdx, + isLast); + } + +private: + uint32_t goFlag = 1; + AscendC::LocalTensor loUbTensor; + AscendC::LocalTensor dmUbTensor; + AscendC::LocalTensor llUbTensor; + AscendC::LocalTensor glUbTensor; + AscendC::LocalTensor tvUbTensor; + AscendC::LocalTensor goUbTensor; + + CopyGmToUbInput copyGmToUbInput; + CopyUbToGmOutput copyUbToGmOutput; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_FA_RESCALE_O_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_fa_softmax.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_fa_softmax.hpp new file mode 100644 index 00000000..776ab7b7 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_fa_softmax.hpp @@ -0,0 +1,323 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
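// Host-side twin of BlockEpilogue::SetMask above: it returns the two 64-bit mask words instead of
// programming the vector mask, so the branchy bit math can be sanity-checked, e.g.
// MaskWords(100) == { (1ULL << 36) - 1, 0xFFFFFFFFFFFFFFFF }.
#include <cstdint>
#include <utility>

inline std::pair<uint64_t, uint64_t> MaskWords(int32_t len)
{
    const int32_t MAX_MASK_LEN = 128;
    const int32_t HALF_MASK_LEN = 64;
    if (len >= MAX_MASK_LEN) {
        return {~0ULL, ~0ULL};                   // all 128 lanes enabled
    }
    if (len < HALF_MASK_LEN) {
        return {0x0, ((uint64_t)1 << len) - 1};  // only the low word, partially set
    }
    const int32_t high = len - HALF_MASK_LEN;
    return {((uint64_t)1 << high) - 1, ~0ULL};   // low word full, high word partial
}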
+ */ + +#ifndef CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_FA_SOFTMAX_HPP +#define CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_FA_SOFTMAX_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue +{ +public: + // Type aliases + using DispatchPolicy = EpilogueAtlasA2FASoftmax; + using ArchTag = typename DispatchPolicy::ArchTag; + using ElementOutput = typename OutputType_::Element; + using ElementInput = typename InputType_::Element; + using ElementMask = typename MaskType_::Element; + + using LayoutOutput = typename OutputType_::Layout; + using LayoutInput = typename InputType_::Layout; + using LayoutMask = typename MaskType_::Layout; + + using CopyGmToUbInput = Tile::CopyGm2Ub; + using CopyGmToUbMask = Tile::CopyGm2Ub; + using CopyUbToGmOutput = Tile::CopyUb2Gm; + + static constexpr uint32_t HALF_ELENUM_PER_BLK = 16; + static constexpr uint32_t FLOAT_ELENUM_PER_BLK = 8; + static constexpr uint32_t HALF_ELENUM_PER_VECCALC = 128; + static constexpr uint32_t FLOAT_ELENUM_PER_VECCALC = 64; + static constexpr uint32_t UB_TILE_SIZE = 16384; // 64 * 128 * 2B + static constexpr uint32_t UB_LINE_SIZE = 512; // 128 * 2 * 2B + static constexpr uint32_t HALF_ELENUM_PER_LINE = 256; // 128 * 2 + static constexpr uint32_t FLOAT_ELENUM_PER_LINE = 128; // 128 + static constexpr uint32_t MULTIPLIER = 2; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, half tor_) + { + constexpr uint32_t LS32_UB_TENSOR_OFFSET = 2 * UB_TILE_SIZE; + constexpr uint32_t MASK_UB_TENSOR_OFFSET = 4 * UB_TILE_SIZE; + constexpr uint32_t LM_UB_TENSOR_OFFSET = 7 * UB_TILE_SIZE; + constexpr uint32_t HM_UB_TENSOR_OFFSET = 7 * UB_TILE_SIZE + 1 * UB_LINE_SIZE; + constexpr uint32_t GM_UB_TENSOR_OFFSET = 7 * UB_TILE_SIZE + 2 * UB_LINE_SIZE; + constexpr uint32_t DM_UB_TENSOR_OFFSET = 7 * UB_TILE_SIZE + 4 * UB_LINE_SIZE; + constexpr uint32_t LL_UB_TENSOR_OFFSET = 7 * UB_TILE_SIZE + 6 * UB_LINE_SIZE; + constexpr uint32_t TV_UB_TENSOR_OFFSET = 7 * UB_TILE_SIZE + 11 * UB_LINE_SIZE; + + tor = tor_; + lsUbTensor = resource.ubBuf.template GetBufferByByte(0); + lpUbTensor = resource.ubBuf.template GetBufferByByte(0); + ls32UbTensor = resource.ubBuf.template GetBufferByByte(LS32_UB_TENSOR_OFFSET); + maskUbTensor = resource.ubBuf.template GetBufferByByte(MASK_UB_TENSOR_OFFSET); + lmUbTensor = resource.ubBuf.template GetBufferByByte(LM_UB_TENSOR_OFFSET); + hmUbTensor = resource.ubBuf.template GetBufferByByte(HM_UB_TENSOR_OFFSET); + gmUbTensor = resource.ubBuf.template GetBufferByByte(GM_UB_TENSOR_OFFSET); + dmUbTensor = resource.ubBuf.template GetBufferByByte(DM_UB_TENSOR_OFFSET); + llUbTensor = resource.ubBuf.template GetBufferByByte(LL_UB_TENSOR_OFFSET); + tvUbTensor = resource.ubBuf.template GetBufferByByte(TV_UB_TENSOR_OFFSET); + + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + } + + CATLASS_DEVICE + void SetMask(int32_t len) + { + const int32_t MAX_MASK_LEN = 128; + const int32_t HALF_MASK_LEN = 64; + if (len >= MAX_MASK_LEN) { + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + return; + } + int32_t highMask = len - HALF_MASK_LEN > 0 ? 
len - HALF_MASK_LEN : 0; + int32_t lowMask = len - HALF_MASK_LEN >= 0 ? HALF_MASK_LEN : len; + if (len < HALF_MASK_LEN) { + AscendC::SetVectorMask(0x0, ((uint64_t)1 << lowMask) - 1); + } else { + AscendC::SetVectorMask(((uint64_t)1 << highMask) - 1, 0xffffffffffffffff); + } + } + + CATLASS_DEVICE + void SetVcgMask(int32_t len) + { + const int32_t MAX_LEN = 16; + if (len > MAX_LEN) { + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + return; + } + uint64_t subMask = ((uint64_t)1 << len) - 1; + uint64_t maskValue = (subMask << 48) + (subMask << 32) + (subMask << 16) + subMask; + AscendC::SetVectorMask(maskValue, maskValue); + } + + CATLASS_DEVICE + void subCoreCompute(AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gInput, + AscendC::GlobalTensor gMask, const LayoutOutput &layoutOutput, + const LayoutInput &layoutInput, const LayoutMask &layoutMask, uint32_t nIdx, + Arch::CrossCoreFlag qkReady) + { + uint32_t subM = layoutInput.shape(0); + uint32_t qkN = layoutInput.shape(1); + uint32_t qkNRound = layoutInput.stride(0); + uint32_t maxSeqlen = layoutMask.stride(0); + uint32_t offset = pingpongFlag * UB_TILE_SIZE / sizeof(ElementInput); + uint32_t subMAligned128 = (subM + HALF_ELENUM_PER_VECCALC - 1) / HALF_ELENUM_PER_VECCALC; + uint32_t subMRound = (subM + HALF_ELENUM_PER_BLK - 1) / HALF_ELENUM_PER_BLK * HALF_ELENUM_PER_BLK; + + // Get the layout on UB + auto layoutInUb = LayoutInput::template MakeLayoutInUb(MatrixCoord{subM, qkN}); + + if (subM > 0) { + // Copy mask + AscendC::WaitFlag(EVENT_ID0); + copyGmToUbMask(maskUbTensor, gMask, layoutInUb, layoutMask); + } + Arch::CrossCoreWaitFlag(qkReady); + if (subM > 0) { + AscendC::WaitFlag(pingpongFlag); + // Copy QK + copyGmToUbInput(lsUbTensor[offset], gInput, layoutInUb, layoutInput); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + // ls = tor * ls + AscendC::Muls(lsUbTensor[offset], lsUbTensor[offset], tor, (uint64_t)0, + (subM * qkNRound + HALF_ELENUM_PER_VECCALC - 1) / HALF_ELENUM_PER_VECCALC, + AscendC::UnaryRepeatParams(1, 1, 8, 8)); + AscendC::PipeBarrier(); + // ls = ls + mask + AscendC::Add(lsUbTensor[offset], lsUbTensor[offset], maskUbTensor, (uint64_t)0, + (subM * qkNRound + HALF_ELENUM_PER_VECCALC - 1) / HALF_ELENUM_PER_VECCALC, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + AscendC::SetFlag(EVENT_ID0); + // lm = rowmax(ls) + if (qkN <= HALF_ELENUM_PER_VECCALC) { + SetMask(qkN); + AscendC::BlockReduceMax(tvUbTensor.ReinterpretCast(), lsUbTensor[offset], subM, 0, 2, + 1, qkNRound / HALF_ELENUM_PER_BLK); + AscendC::PipeBarrier(); + SetVcgMask(qkNRound / HALF_ELENUM_PER_BLK); + AscendC::BlockReduceMax( + lmUbTensor, tvUbTensor.ReinterpretCast(), + (subM * HALF_ELENUM_PER_BLK + HALF_ELENUM_PER_VECCALC - 1) / HALF_ELENUM_PER_VECCALC, 0, 1, 1, 8); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + AscendC::PipeBarrier(); + if (nIdx == 0) { + // hm = lm + AscendC::DataCopy(hmUbTensor, lmUbTensor, + AscendC::DataCopyParams(1, subMRound / HALF_ELENUM_PER_BLK, 0, 0)); + AscendC::PipeBarrier(); + } else { + // hm = vmax(lm, gm) + AscendC::Max(hmUbTensor, lmUbTensor, gmUbTensor, (uint64_t)0, subMAligned128, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + // dm = gm - hm + AscendC::Sub(dmUbTensor[nIdx % MULTIPLIER * HALF_ELENUM_PER_LINE], gmUbTensor, hmUbTensor, + (uint64_t)0, subMAligned128, AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + } + // gm = hm + AscendC::DataCopy(gmUbTensor, hmUbTensor, + 
AscendC::DataCopyParams(1, subMRound / HALF_ELENUM_PER_BLK, 0, 0)); + AscendC::PipeBarrier(); + // hm_block = brcb(hm), 存放于tv + AscendC::Brcb(tvUbTensor.ReinterpretCast(), hmUbTensor.ReinterpretCast(), + subMRound / FLOAT_ELENUM_PER_BLK, AscendC::BrcbRepeatParams(1, 8)); + AscendC::PipeBarrier(); + // ls = ls - hm_block + for (uint32_t vsubIdx = 0; vsubIdx < qkN / HALF_ELENUM_PER_VECCALC; vsubIdx++) { + AscendC::Sub(lsUbTensor[offset + vsubIdx * HALF_ELENUM_PER_VECCALC], + lsUbTensor[offset + vsubIdx * HALF_ELENUM_PER_VECCALC], + tvUbTensor.ReinterpretCast(), (uint64_t)0, subM, + AscendC::BinaryRepeatParams(1, 1, 0, qkNRound / HALF_ELENUM_PER_BLK, + qkNRound / HALF_ELENUM_PER_BLK, 1)); + } + if (qkN % HALF_ELENUM_PER_VECCALC > 0) { + SetMask(qkN % HALF_ELENUM_PER_VECCALC); + AscendC::Sub(lsUbTensor[offset + qkN / HALF_ELENUM_PER_VECCALC * HALF_ELENUM_PER_VECCALC], + lsUbTensor[offset + qkN / HALF_ELENUM_PER_VECCALC * HALF_ELENUM_PER_VECCALC], + tvUbTensor.ReinterpretCast(), (uint64_t)0, subM, + AscendC::BinaryRepeatParams(1, 1, 0, qkNRound / HALF_ELENUM_PER_BLK, + qkNRound / HALF_ELENUM_PER_BLK, 1)); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + AscendC::PipeBarrier(); + // ls32 = castfp16to32(ls) + AscendC::Cast( + ls32UbTensor, lsUbTensor[offset], AscendC::RoundMode::CAST_NONE, (uint64_t)0, + (subM * qkNRound + FLOAT_ELENUM_PER_VECCALC - 1) / FLOAT_ELENUM_PER_VECCALC, + AscendC::UnaryRepeatParams(1, 1, 8, 4)); + AscendC::PipeBarrier(); + // ls32 = exp(ls32) + AscendC::Exp(ls32UbTensor, ls32UbTensor, (uint64_t)0, + (subM * qkNRound + FLOAT_ELENUM_PER_VECCALC - 1) / FLOAT_ELENUM_PER_VECCALC, + AscendC::UnaryRepeatParams(1, 1, 8, 8)); + AscendC::PipeBarrier(); + // lp = castfp32to16(ls) + AscendC::Cast( + lpUbTensor[offset], ls32UbTensor, AscendC::RoundMode::CAST_NONE, (uint64_t)0, + (subM * qkNRound + FLOAT_ELENUM_PER_VECCALC - 1) / FLOAT_ELENUM_PER_VECCALC, + AscendC::UnaryRepeatParams(1, 1, 4, 8)); + AscendC::PipeBarrier(); + AscendC::SetFlag(EVENT_ID0); + // ll = rowsum(ls32) + if (qkN <= FLOAT_ELENUM_PER_VECCALC) { + SetMask(qkN); + AscendC::RepeatReduceSum(llUbTensor[nIdx % MULTIPLIER * FLOAT_ELENUM_PER_LINE], + ls32UbTensor, subM, 0, 0, 1, 1, qkNRound / FLOAT_ELENUM_PER_BLK); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else { + for (uint32_t vaddIdx = 1; vaddIdx < qkN / FLOAT_ELENUM_PER_VECCALC; vaddIdx++) { + AscendC::Add( + ls32UbTensor, ls32UbTensor, ls32UbTensor[vaddIdx * FLOAT_ELENUM_PER_VECCALC], (uint64_t)0, subM, + AscendC::BinaryRepeatParams(1, 1, 1, qkNRound / FLOAT_ELENUM_PER_BLK, + qkNRound / FLOAT_ELENUM_PER_BLK, qkNRound / FLOAT_ELENUM_PER_BLK)); + AscendC::PipeBarrier(); + } + if (qkN % FLOAT_ELENUM_PER_VECCALC > 0) { + SetMask(qkN % FLOAT_ELENUM_PER_VECCALC); + AscendC::Add( + ls32UbTensor, ls32UbTensor, + ls32UbTensor[qkN / FLOAT_ELENUM_PER_VECCALC * FLOAT_ELENUM_PER_VECCALC], (uint64_t)0, subM, + AscendC::BinaryRepeatParams(1, 1, 1, qkNRound / FLOAT_ELENUM_PER_BLK, + qkNRound / FLOAT_ELENUM_PER_BLK, qkNRound / FLOAT_ELENUM_PER_BLK)); + AscendC::PipeBarrier(); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + AscendC::RepeatReduceSum(llUbTensor[nIdx % MULTIPLIER * FLOAT_ELENUM_PER_LINE], + ls32UbTensor, subM, 0, 0, 1, 1, qkNRound / FLOAT_ELENUM_PER_BLK); + } + AscendC::PipeBarrier(); + AscendC::WaitFlag(EVENT_ID0); + copyUbToGmOutput(gOutput, lpUbTensor[offset], layoutOutput, layoutInUb); + AscendC::SetFlag(pingpongFlag); + pingpongFlag = 1 - pingpongFlag; + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor 
gOutput, AscendC::GlobalTensor gInput, + AscendC::GlobalTensor gMask, const LayoutOutput &layoutOutput, + const LayoutInput &layoutInput, const LayoutMask &layoutMask, GemmCoord actualBlockShape, + uint32_t nIdx, Arch::CrossCoreFlag qkReady) + { + uint32_t mActual = actualBlockShape.m(); + uint32_t nActual = actualBlockShape.n(); + + uint32_t subBlockIdx = AscendC::GetSubBlockIdx(); + uint32_t subBlockNum = AscendC::GetSubBlockNum(); + + uint32_t mActualPerSubBlock = CeilDiv(mActual, subBlockNum); + uint32_t mActualThisSubBlock = (subBlockIdx == 0) ? mActualPerSubBlock : (mActual - mActualPerSubBlock); + uint32_t mOffset = subBlockIdx * mActualPerSubBlock; + uint32_t nOffset = 0; + + int64_t offsetOutput = layoutOutput.GetOffset(MatrixCoord(mOffset, nOffset)); + auto gOutputThisSubBlock = gOutput[offsetOutput]; + auto layoutOutputThisSubBlock = layoutOutput.GetTileLayout(MatrixCoord(mActualThisSubBlock, nActual)); + + int64_t offsetInput = layoutInput.GetOffset(MatrixCoord(mOffset, nOffset)); + auto gInputThisSubBlock = gInput[offsetInput]; + auto layoutInputThisSubBlock = layoutInput.GetTileLayout(MatrixCoord(mActualThisSubBlock, nActual)); + + int64_t offsetMask = layoutMask.GetOffset(MatrixCoord(mOffset, nOffset)); + auto gMaskThisSubBlock = gMask[offsetMask]; + auto layoutMaskThisSubBlock = layoutMask.GetTileLayout(MatrixCoord(mActualThisSubBlock, nActual)); + + subCoreCompute(gOutputThisSubBlock, gInputThisSubBlock, gMaskThisSubBlock, layoutOutputThisSubBlock, + layoutInputThisSubBlock, layoutMaskThisSubBlock, nIdx, qkReady); + } + +private: + half tor; + uint32_t pingpongFlag = 0; + AscendC::LocalTensor lsUbTensor; + AscendC::LocalTensor lpUbTensor; + AscendC::LocalTensor ls32UbTensor; + AscendC::LocalTensor maskUbTensor; + AscendC::LocalTensor lmUbTensor; + AscendC::LocalTensor hmUbTensor; + AscendC::LocalTensor gmUbTensor; + AscendC::LocalTensor dmUbTensor; + AscendC::LocalTensor llUbTensor; + AscendC::LocalTensor tvUbTensor; + + CopyGmToUbInput copyGmToUbInput; + CopyGmToUbMask copyGmToUbMask; + CopyUbToGmOutput copyUbToGmOutput; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_FA_SOFTMAX_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_gemm.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_gemm.hpp new file mode 100644 index 00000000..3d6dd8c3 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_gemm.hpp @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
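// The softmax epilogue above is the standard "online softmax" recurrence spread across vector
// instructions (rowmax -> running max gm, correction exp(gm_old - gm_new), rowsum -> running sum
// gl). A scalar, single-row reference of the same bookkeeping, for reading the Brcb/Exp/Mul chain:
#include <cmath>
#include <vector>

inline void OnlineSoftmaxRow(const std::vector<std::vector<float>> &tiles)
{
    float gm = -INFINITY;  // running max over all columns processed so far
    float gl = 0.0f;       // running sum of exp(x - gm)
    for (const auto &tile : tiles) {
        float lm = -INFINITY;                // lm = rowmax(tile)
        for (float x : tile) lm = std::fmax(lm, x);
        const float hm = std::fmax(lm, gm);  // hm = max(lm, gm)
        const float dm = gm - hm;            // dm = old max - new max (<= 0)
        float ll = 0.0f;                     // ll = rowsum(exp(tile - hm))
        for (float x : tile) ll += std::exp(x - hm);
        gl = std::exp(dm) * gl + ll;         // rescale the old sum, then add the new tile
        gm = hm;
    }
    // exp(x - gm) / gl are the weights the rescale-O epilogue divides by at the end.
}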
+ */ + +#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_GEMM_HPP +#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_GEMM_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" +#include "catlass/gemm/helper.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue +{ +public: + // Type aliases + using DispatchPolicy = EpilogueAtlasA2Gemm; + using ArchTag = typename DispatchPolicy::ArchTag; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementX = typename XType_::Element; + using LayoutX = typename XType_::Layout; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + using TileElemWiseEpilogueAdd = TileElemWiseEpilogueAdd_; + using TileElemWiseEpilogueMuls = TileElemWiseEpilogueMuls_; + using TileElemWiseCastD = TileElemWiseCastD_; + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbX = typename TileCopy_::CopyGmToUbX; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + const uint32_t SubNum = AscendC::GetSubBlockNum(); + const uint32_t UBSize = ArchTag::UB_SIZE; + static constexpr bool isNeedCast = !std::is_same::value; + static constexpr uint32_t COMPUTE_LENGTH = TileElemWiseEpilogueAdd::COMPUTE_LENGTH; + + using ElementCompute = + typename Catlass::Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using ElementScalar = ElementCompute; + + // Check if ArchTag is matched + static_assert(std::is_same_v, + "Tile epilogue's ArchTag mismatch"); + static_assert(std::is_same_v, + "Tile epilogue's ArchTag mismatch"); + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + struct Params { + ElementScalar alpha; + ElementScalar beta; + GM_ADDR ptrX; + LayoutX layoutX; + GM_ADDR ptrD; + LayoutD layoutD; + + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(ElementScalar alpha_, ElementScalar beta_, GM_ADDR ptrX_, LayoutX layoutX_, GM_ADDR ptrD_, + LayoutD layoutD_) + : alpha(alpha_), beta(beta_), ptrX(ptrX_), layoutX(layoutX_), ptrD(ptrD_), layoutD(layoutD_) + {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, GemmCoord blockShape_, Params const ¶ms_, + uint32_t ubByteStart = 0) + : blockShapeMNK(blockShape_), params(params_) + { + uint32_t maxMPerBlock = blockShapeMNK.m(); + uint32_t maxNPerBlock = blockShapeMNK.n(); + uint32_t tileSize = maxMPerBlock * maxNPerBlock / SubNum; + uint32_t ubCSize = tileSize * sizeof(ElementC); + uint32_t ubXSize = tileSize * sizeof(ElementX); + uint32_t ubDSize = tileSize * sizeof(ElementD); + uint32_t ubXCastSize = tileSize * sizeof(ElementCompute); + uint32_t ubDCastSize = tileSize * sizeof(ElementCompute); + ubCTensor = resource.ubBuf.template GetBufferByByte(ubByteStart); + ubByteStart += ubCSize; + ubXTensor = resource.ubBuf.template GetBufferByByte(ubByteStart); + ubByteStart += ubXSize; + ubDTensor = resource.ubBuf.template GetBufferByByte(ubByteStart); + ubByteStart += ubDSize; + if constexpr (isNeedCast) { + ubXTensorCast = resource.ubBuf.template GetBufferByByte(ubByteStart); + ubByteStart += ubXCastSize; + ubDTensorCast = resource.ubBuf.template GetBufferByByte(ubByteStart); + ; + ubByteStart += ubDCastSize; + } + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + AscendC::SetFlag(EVENT_ID0); + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + AscendC::WaitFlag(EVENT_ID0); + 
AscendC::WaitFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID0); + } + + CATLASS_DEVICE + void operator()(GemmCoord const &actualShapeMNK, GemmCoord const &blockCoordMNK, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutC, uint64_t const &offset) + { + AscendC::GlobalTensor gmBlockX; + gmBlockX.SetGlobalBuffer(reinterpret_cast<__gm__ ElementX *>(params.ptrX)); + AscendC::GlobalTensor gmBlockD; + gmBlockD.SetGlobalBuffer(reinterpret_cast<__gm__ ElementD *>(params.ptrD)); + MatrixCoord blockShapeMN = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoordMN = blockCoordMNK.GetCoordMN(); + MatrixCoord actualShapeMN = actualShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoordMN * blockShapeMN; + MatrixCoord subblockShape{CeilDiv(actualShapeMN.row(), SubNum), actualShapeMN.column()}; + MatrixCoord subblockCoord{AscendC::GetSubBlockIdx(), 0}; + MatrixCoord actualSubblockShape = + MatrixCoord::Min(subblockShape, actualShapeMN - subblockCoord * subblockShape); + MatrixCoord subblockOffset = subblockCoord * subblockShape; + LayoutC layoutInUb{blockShapeMN.row() / SubNum, blockShapeMN.column()}; + AscendC::WaitFlag(EVENT_ID1); + auto layoutTileX = params.layoutX.GetTileLayout(actualSubblockShape); + auto layoutXInUb = layoutInUb.GetTileLayout(actualSubblockShape); + auto gmTileX = gmBlockX[offset + params.layoutX.GetOffset(blockOffset + subblockOffset)]; + copyGmToUbX(ubXTensor, gmTileX, layoutXInUb, layoutTileX); + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + if constexpr (isNeedCast) { + AscendC::Cast(ubXTensorCast, ubXTensor, AscendC::RoundMode::CAST_NONE, + COMPUTE_LENGTH); + AscendC::PipeBarrier(); + tileElemWiseEpilogueMuls(ubXTensorCast, ubXTensorCast, (ElementCompute)params.beta); + } else { + tileElemWiseEpilogueMuls(ubXTensor, ubXTensor, (ElementX)params.beta); + } + AscendC::WaitFlag(EVENT_ID0); + auto layoutTileC = layoutC.GetTileLayout(actualSubblockShape); + auto layoutCInUb = layoutInUb.GetTileLayout(actualSubblockShape); + auto gmTileC = gmBlockC[offset + layoutC.GetOffset(blockOffset + subblockOffset)]; + copyGmToUbC(ubCTensor, gmTileC, layoutCInUb, layoutTileC); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + tileElemWiseEpilogueMuls(ubCTensor, ubCTensor, (ElementC)params.alpha); + AscendC::WaitFlag(EVENT_ID0); + AscendC::PipeBarrier(); + if constexpr (isNeedCast) { + tileElemWiseEpilogueAdd(ubDTensorCast, ubCTensor, ubXTensorCast); + } else { + tileElemWiseEpilogueAdd(ubDTensor, ubCTensor, ubXTensor); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + AscendC::PipeBarrier(); + if constexpr (isNeedCast) { + tileElemWiseCastD(ubDTensor, ubDTensorCast); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + auto layoutDInGm = params.layoutD.GetTileLayout(actualSubblockShape); + auto layoutTileD = layoutInUb.GetTileLayout(actualSubblockShape); + auto gmTileD = gmBlockD[offset + params.layoutD.GetOffset(blockOffset + subblockOffset)]; + copyUbToGmD(gmTileD, ubDTensor, layoutDInGm, layoutTileD); + AscendC::SetFlag(EVENT_ID0); + } + +private: + GemmCoord blockShapeMNK; + Params params; + + AscendC::LocalTensor ubCTensor; + AscendC::LocalTensor ubXTensor; + AscendC::LocalTensor ubDTensor; + AscendC::LocalTensor ubXTensorCast; + AscendC::LocalTensor ubDTensorCast; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbX copyGmToUbX; + CopyUbToGmD copyUbToGmD; + + TileElemWiseEpilogueAdd tileElemWiseEpilogueAdd; + TileElemWiseEpilogueMuls tileElemWiseEpilogueMuls; + TileElemWiseCastD tileElemWiseCastD; +}; +} // namespace 
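// The block epilogue above evaluates D = alpha * C + beta * X tile by tile, casting to ElementD
// at the end when it differs from the float accumulator. A scalar reference of the same formula:
#include <cstddef>

inline void GemmEpilogueRef(float alpha, float beta, const float *C, const float *X, float *D, std::size_t n)
{
    for (std::size_t i = 0; i < n; ++i) {
        D[i] = alpha * C[i] + beta * X[i];  // Muls(X, beta), Muls(C, alpha), then Add
    }
}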
Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_GEMM_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_gemv.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_gemv.hpp new file mode 100644 index 00000000..4b24635d --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_gemv.hpp @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_GEMV_HPP +#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_GEMV_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemv/helper.hpp" +#include "catlass/gemv_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue +{ +public: + using DispatchPolicy = EpilogueAtlasA2Gemv; + using ArchTag = typename DispatchPolicy::ArchTag; + + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementY = typename YType_::Element; + using LayoutY = typename YType_::Layout; + using ElementZ = typename ZType_::Element; + using LayoutZ = typename ZType_::Layout; + + using TileElemWiseEpilogueAdd = TileElemWiseEpilogueAdd_; + using TileElemWiseEpilogueMuls = TileElemWiseEpilogueMuls_; + + using CopyGmToUbY = typename TileCopy_::CopyGmToUbC; + using CopyGmToubC = typename TileCopy_::CopyGmToUbX; + using CopyUbToGmZ = typename TileCopy_::CopyUbToGmD; + + static constexpr uint32_t COMPUTE_LENGTH = TileElemWiseEpilogueMuls::COMPUTE_LENGTH; + + static constexpr bool isNeedCast = !std::is_same::value; + + using ElementCompute = + typename Catlass::Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using ElementScalar = ElementCompute; + using TensorCoord = layout::VectorLayout::TensorCoord; + + // check the layout of Y, C and Z + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v, + "Layout type of Y, C and Z must be VectorLayout"); + + using LayoutComputeInUb = layout::VectorLayout; + + // Check if ArchTag is matched + static_assert(std::is_same_v, + "Tile epilogue's ArchTag mismatch"); + + struct Params { + ElementScalar alpha; + ElementScalar beta; + GM_ADDR ptrY; + LayoutY layoutY; + GM_ADDR ptrZ; + LayoutZ layoutZ; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(ElementScalar alpha_, ElementScalar beta_, GM_ADDR ptrY_, LayoutC layoutY_, GM_ADDR ptrZ_, + LayoutZ layoutZ_) + : alpha(alpha_), beta(beta_), ptrY(ptrY_), layoutY(layoutY_), ptrZ(ptrZ_), layoutZ(layoutZ_) + {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, Params const ¶ms) : params(params) + { + ubC = resource.ubBuf.template GetBufferByByte(0); + ubY = resource.ubBuf.template GetBufferByByte(COMPUTE_LENGTH * 
sizeof(ElementC)); + ubYCast = resource.ubBuf.template GetBufferByByte(COMPUTE_LENGTH * sizeof(ElementC)); + ubZ = resource.ubBuf.template GetBufferByByte(COMPUTE_LENGTH * sizeof(ElementY) + + COMPUTE_LENGTH * sizeof(ElementC)); + ubZCast = resource.ubBuf.template GetBufferByByte(COMPUTE_LENGTH * sizeof(ElementY) + + COMPUTE_LENGTH * sizeof(ElementC)); + AscendC::SetFlag(EVENT_ID0); + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + AscendC::WaitFlag(EVENT_ID0); + } + + CATLASS_DEVICE + void operator()(TensorCoord const &blockOffsetMN, TensorCoord const &actualBlockShapeMN, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutBlockC) + { + TensorCoord actualBlockShape = actualBlockShapeMN; + TensorCoord blockOffset = blockOffsetMN; + + TensorCoord subblockShape{CeilDiv(actualBlockShape[0], static_cast(AscendC::GetSubBlockNum()))}; + TensorCoord subblockCoord{static_cast(AscendC::GetSubBlockIdx())}; + + TensorCoord actualSubblockShape = + TensorCoord::Min(subblockShape, actualBlockShape - subblockCoord * subblockShape); + TensorCoord subblockOffset = subblockCoord * subblockShape; + + // Get the data and layout of C + auto gmSubblockC = gmBlockC[layoutBlockC.GetOffset(subblockOffset)]; + auto layoutSubblockC = layoutBlockC.GetTileLayout(actualSubblockShape); + + // Get the data and layout of y + AscendC::GlobalTensor gmY; + gmY.SetGlobalBuffer(reinterpret_cast<__gm__ ElementY *>(params.ptrY)); + auto gmSubblockY = gmY[params.layoutY.GetOffset(blockOffset + subblockOffset)]; + auto layoutSubblockY = params.layoutY.GetTileLayout(actualSubblockShape); + + // Get the data and layout of Z + AscendC::GlobalTensor gmZ; + gmZ.SetGlobalBuffer(reinterpret_cast<__gm__ ElementZ *>(params.ptrZ)); + auto gmSubblockZ = gmZ[params.layoutZ.GetOffset(blockOffset + subblockOffset)]; + auto layoutSubblockZ = params.layoutZ.GetTileLayout(actualSubblockShape); + + // get the layout on UB + auto layoutComputeInUb = LayoutComputeInUb::template MakeLayoutInUb(actualSubblockShape); + + // load C(A*x) from gm to ub + AscendC::WaitFlag(EVENT_ID0); + copyGmToubC(ubC, gmSubblockC, layoutComputeInUb, layoutSubblockC); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + // compute C * alpha + tileEpilogueMul(ubC, ubC, params.alpha); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + // load Y from gm to ub + copyGmToUbY(ubY, gmSubblockY, layoutComputeInUb, layoutSubblockY); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + // compute Y * beta + if constexpr (isNeedCast) { + AscendC::Cast(ubYCast, ubY, AscendC::RoundMode::CAST_NONE, COMPUTE_LENGTH); + AscendC::PipeBarrier(); + tileEpilogueMul(ubYCast, ubYCast, params.beta); + AscendC::PipeBarrier(); + } else { + tileEpilogueMul(ubY, ubY, params.beta); + AscendC::PipeBarrier(); + } + + if constexpr (isNeedCast) { + tileEpilogueAdd(ubZCast, ubC, ubYCast); + } else { + tileEpilogueAdd(ubZ, ubC, ubY); + } + + if constexpr (isNeedCast) { + AscendC::PipeBarrier(); + AscendC::Cast(ubZ, ubZCast, AscendC::RoundMode::CAST_RINT, COMPUTE_LENGTH); + AscendC::PipeBarrier(); + } + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + copyUbToGmZ(gmSubblockZ, ubZ, layoutSubblockZ, layoutComputeInUb); + AscendC::SetFlag(EVENT_ID0); + }; + +private: + Params params; + + AscendC::LocalTensor ubY; + AscendC::LocalTensor ubYCast; + AscendC::LocalTensor ubC; + AscendC::LocalTensor ubZ; + AscendC::LocalTensor ubZCast; + + TileElemWiseEpilogueAdd tileEpilogueAdd; + TileElemWiseEpilogueMuls tileEpilogueMul; + + CopyGmToUbY 
copyGmToUbY; + CopyGmToubC copyGmToubC; + CopyUbToGmZ copyUbToGmZ; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_GEMV_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_fd_rescale_o.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_fd_rescale_o.hpp new file mode 100644 index 00000000..3135d35b --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_fd_rescale_o.hpp @@ -0,0 +1,260 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_FD_RESCALE_O_HPP +#define CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_FD_RESCALE_O_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue, OutputType_, InputType_> +{ +public: + // Type aliases + using DispatchPolicy = EpilogueAtlasA2MLAFDRescaleO; + using ArchTag = typename DispatchPolicy::ArchTag; + using ElementOutput = typename OutputType_::Element; + using ElementInput = typename InputType_::Element; + using LayoutOutput = typename OutputType_::Layout; + using LayoutInput = typename InputType_::Layout; + + static constexpr uint32_t KV_SPLIT_MAX = DispatchPolicy::KV_SPLIT_MAX; + static constexpr uint32_t HEADS_PROCESS_MAX = DispatchPolicy::HEADS_PROCESS_MAX; + static constexpr uint32_t COMPUTE_ELE_NUM = DispatchPolicy::COMPUTE_ELE_NUM; + static constexpr uint32_t FLOAT_ELENUM_PER_VECCALC = 64; + static constexpr uint32_t FLOAT_BLOCK_SIZE = 8; + static constexpr uint32_t STAGES = 2; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, uint32_t kvSplitCoreNum_) + { + kvSplitCoreNum = kvSplitCoreNum_; + + uint32_t ubOffset = 0; + oIn[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += COMPUTE_ELE_NUM * sizeof(float); + oIn[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += COMPUTE_ELE_NUM * sizeof(float); + oTemp[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += COMPUTE_ELE_NUM * sizeof(float); + oTemp[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += COMPUTE_ELE_NUM * sizeof(float); + oSum = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += COMPUTE_ELE_NUM * sizeof(float); + out = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += COMPUTE_ELE_NUM * sizeof(ElementOutput); + lIn = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += KV_SPLIT_MAX * HEADS_PROCESS_MAX * sizeof(float); + lExp = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += KV_SPLIT_MAX * HEADS_PROCESS_MAX * sizeof(float); + lMax = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += 
HEADS_PROCESS_MAX * sizeof(float); + lSum = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += HEADS_PROCESS_MAX * sizeof(float); + lBrcb[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += HEADS_PROCESS_MAX * FLOAT_BLOCK_SIZE * sizeof(float); + lBrcb[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + AscendC::SetFlag(EVENT_ID2); + } + CATLASS_DEVICE + ~BlockEpilogue() + { + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID2); + } + + CATLASS_DEVICE + void SetMask(int32_t len) + { + constexpr int32_t MAX_MASK_LEN = 128; + constexpr int32_t HALF_MASK_LEN = 64; + if (len >= MAX_MASK_LEN) { + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + return; + } + int32_t highMask = len - HALF_MASK_LEN > 0 ? len - HALF_MASK_LEN : 0; + int32_t lowMask = len - HALF_MASK_LEN >= 0 ? HALF_MASK_LEN : len; + if (len < HALF_MASK_LEN) { + AscendC::SetVectorMask(0x0, ((uint64_t)1 << lowMask) - 1); + } else { + AscendC::SetVectorMask(((uint64_t)1 << highMask) - 1, 0xffffffffffffffff); + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gOCoreTmp, + AscendC::GlobalTensor gl, uint32_t actualHeads, uint32_t headsProcess, + uint32_t headSize) + { + uint32_t kvSplitRound = (kvSplitCoreNum + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE * FLOAT_BLOCK_SIZE; + + AscendC::WaitFlag(EVENT_ID2); + AscendC::DataCopyPad(lIn, gl, + AscendC::DataCopyExtParams(actualHeads, kvSplitCoreNum * sizeof(ElementInput), 0, + (KV_SPLIT_MAX - kvSplitCoreNum) / FLOAT_BLOCK_SIZE, 0), + AscendC::DataCopyPadExtParams(false, 0, 0, 0)); + + AscendC::SetFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID2); + + SetMask(kvSplitCoreNum); + AscendC::WholeReduceMax(lMax, lIn, (int32_t)0, actualHeads, 1, 1, 8, + AscendC::ReduceOrder::ORDER_ONLY_VALUE); + AscendC::PipeBarrier(); + + for (uint32_t i = 0; i < kvSplitRound / FLOAT_BLOCK_SIZE; i++) { + AscendC::Brcb( + lExp[i * FLOAT_BLOCK_SIZE], lMax, (headsProcess + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE, + AscendC::BrcbRepeatParams(KV_SPLIT_MAX / FLOAT_BLOCK_SIZE, 8 * KV_SPLIT_MAX / FLOAT_BLOCK_SIZE)); + } + AscendC::PipeBarrier(); + + SetMask(kvSplitCoreNum); + AscendC::Sub(lExp, lIn, lExp, (uint64_t)0, actualHeads, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + + AscendC::Exp(lExp, lExp, (uint64_t)0, actualHeads, AscendC::UnaryRepeatParams(1, 1, 8, 8)); + AscendC::PipeBarrier(); + + AscendC::RepeatReduceSum(lSum, lExp, actualHeads, 0, 0, 1, 1, 8); + AscendC::PipeBarrier(); + + AscendC::Ln(lSum, lSum, (headsProcess + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE * FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + + AscendC::Add(lSum, lSum, lMax, (headsProcess + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE * FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + + for (uint32_t i = 0; i < kvSplitRound / FLOAT_BLOCK_SIZE; i++) { + AscendC::Brcb( + lExp[i * FLOAT_BLOCK_SIZE], lSum, (headsProcess + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE, + AscendC::BrcbRepeatParams(KV_SPLIT_MAX / FLOAT_BLOCK_SIZE, 8 * KV_SPLIT_MAX / FLOAT_BLOCK_SIZE)); + } + AscendC::PipeBarrier(); + + SetMask(kvSplitCoreNum); + AscendC::Sub(lExp, lIn, lExp, (uint64_t)0, actualHeads, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + AscendC::SetFlag(EVENT_ID2); + + AscendC::Exp(lExp, lExp, (uint64_t)0, actualHeads, AscendC::UnaryRepeatParams(1, 1, 8, 8)); 
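+ // lExp now holds exp(lse_i - lse_combined): the per-KV-split combine weight that scales each partial output in the accumulation loop below.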
+ AscendC::PipeBarrier(); + + // preload + AscendC::WaitFlag(EVENT_ID0); + AscendC::DataCopyPad( + oIn[0], gOCoreTmp, + AscendC::DataCopyExtParams(actualHeads, headSize * sizeof(ElementInput), + (kvSplitCoreNum * headSize - headSize) * sizeof(ElementInput), 0, 0), + AscendC::DataCopyPadExtParams(false, 0, 0, 0)); + AscendC::SetFlag(EVENT_ID0); + + AscendC::SetFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID2); + + SetMask(FLOAT_ELENUM_PER_VECCALC); + uint32_t bufferId = 0; + for (uint32_t i = 0; i < kvSplitCoreNum; i++) { + // load next o + if (i < kvSplitCoreNum - 1) { + uint32_t nextBufferId = 1 - bufferId; + AscendC::WaitFlag(oInEventList[nextBufferId]); + AscendC::DataCopyPad( + oIn[nextBufferId], gOCoreTmp[(i + 1) * headSize], + AscendC::DataCopyExtParams(actualHeads, headSize * sizeof(ElementInput), + (kvSplitCoreNum * headSize - headSize) * sizeof(ElementInput), 0, 0), + AscendC::DataCopyPadExtParams(false, 0, 0, 0)); + AscendC::SetFlag(oInEventList[nextBufferId]); + } + + AscendC::PipeBarrier(); + for (uint32_t j = 0; j < actualHeads; j++) { + float a = lExp[j * KV_SPLIT_MAX + i].GetValue(0); + AscendC::SetFlag(oTempEventList[bufferId]); + AscendC::WaitFlag(oTempEventList[bufferId]); + AscendC::Duplicate(lBrcb[bufferId][j * FLOAT_BLOCK_SIZE], a, uint64_t(0), 1, 0, 0); + } + AscendC::PipeBarrier(); + + // calculate current o + AscendC::WaitFlag(oInEventList[bufferId]); + uint32_t loops = (headSize + FLOAT_ELENUM_PER_VECCALC - 1) / FLOAT_ELENUM_PER_VECCALC; + if (i > 0) { + for (uint32_t j = 0; j < loops; j++) { + AscendC::Mul(oTemp[bufferId][j * FLOAT_ELENUM_PER_VECCALC], lBrcb[bufferId], + oIn[bufferId][j * FLOAT_ELENUM_PER_VECCALC], (uint64_t)0, actualHeads, + AscendC::BinaryRepeatParams(1, 0, 1, headSize / FLOAT_BLOCK_SIZE, 1, + headSize / FLOAT_BLOCK_SIZE)); + } + } else { + for (uint32_t j = 0; j < loops; j++) { + AscendC::Mul(oSum[j * FLOAT_ELENUM_PER_VECCALC], lBrcb[bufferId], + oIn[bufferId][j * FLOAT_ELENUM_PER_VECCALC], (uint64_t)0, actualHeads, + AscendC::BinaryRepeatParams(1, 0, 1, headSize / FLOAT_BLOCK_SIZE, 1, + headSize / FLOAT_BLOCK_SIZE)); + } + } + AscendC::PipeBarrier(); + AscendC::SetFlag(oInEventList[bufferId]); + + if (i > 0) { + AscendC::Add(oSum, oSum, oTemp[bufferId], actualHeads * headSize); + } + AscendC::PipeBarrier(); + bufferId = 1 - bufferId; + } + + AscendC::WaitFlag(EVENT_ID0); + if (std::is_same::value) { + AscendC::Cast(out, oSum, AscendC::RoundMode::CAST_RINT, actualHeads * headSize); + } else { + AscendC::Cast(out, oSum, AscendC::RoundMode::CAST_NONE, actualHeads * headSize); + } + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::DataCopyPad(gOutput, out, + AscendC::DataCopyExtParams(actualHeads, headSize * sizeof(ElementOutput), 0, 0, 0)); + AscendC::SetFlag(EVENT_ID0); + } + +private: + uint32_t kvSplitCoreNum = 1; + AscendC::LocalTensor out; + AscendC::LocalTensor oIn[STAGES]; + AscendC::LocalTensor oTemp[STAGES]; + AscendC::LocalTensor lBrcb[STAGES]; + AscendC::LocalTensor oSum; + AscendC::LocalTensor lIn; + AscendC::LocalTensor lExp; + AscendC::LocalTensor lTrans; + AscendC::LocalTensor lMax; + AscendC::LocalTensor lSum; + + int32_t oTempEventList[STAGES] = {0, 1}; + int32_t oInEventList[STAGES] = {0, 1}; +}; +} // namespace Catlass::Epilogue::Block +#endif // CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_FD_RESCALE_O_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_rescale_o.hpp 
b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_rescale_o.hpp new file mode 100644 index 00000000..4e511ef4 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_rescale_o.hpp @@ -0,0 +1,383 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_RESCALE_O_HPP +#define CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_RESCALE_O_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue +{ +public: + // Type aliases + using DispatchPolicy = EpilogueAtlasA2MLARescaleO; + using ArchTag = typename DispatchPolicy::ArchTag; + + using ElementOutput = typename OutputType_::Element; + using ElementUpdate = typename UpdateType_::Element; + using ElementInput = typename InputType_::Element; + + using LayoutOutput = typename OutputType_::Layout; + using LayoutUpdate = typename UpdateType_::Layout; + using LayoutInput = typename InputType_::Layout; + + static constexpr uint32_t HALF_ELENUM_PER_BLK = 16; + static constexpr uint32_t HALF_ELENUM_PER_VECCALC = 128; + static constexpr uint32_t FLOAT_ELENUM_PER_VECCALC = 64; + static constexpr uint32_t HALF_ELENUM_PER_LINE = 256; + static constexpr uint32_t FLOAT_ELENUM_PER_LINE = 128; + static constexpr uint32_t MULTIPLIER = 2; + static constexpr uint32_t FLOAT_BLOCK_SIZE = 8; + static constexpr uint32_t FLOAT_VECTOR_SIZE = 64; + static constexpr uint32_t UB_UINT8_LINE_SIZE = 512; + static constexpr uint32_t UB_UINT8_BLOCK_SIZE_MLA = 16384; + static constexpr uint32_t ROW_WISE_CYCLE_TILE = 8; + static constexpr uint32_t HALF_DM_UB_SIZE = 128; + static constexpr uint32_t HALF_LL_UB_SIZE = 256; + static constexpr uint32_t VECTOR_SIZE = 128; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, uint32_t kvSplitCoreNum_) + { + // Allocate UB space + constexpr uint32_t LO_UB_TENSOR_OFFSET = 4 * UB_UINT8_BLOCK_SIZE_MLA; + constexpr uint32_t DM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 6 * UB_UINT8_LINE_SIZE; + constexpr uint32_t LL_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 9 * UB_UINT8_LINE_SIZE; + constexpr uint32_t GL_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 15 * UB_UINT8_LINE_SIZE; + constexpr uint32_t GO_UB_TENSOR_OFFSET = 8 * UB_UINT8_BLOCK_SIZE_MLA; + constexpr uint32_t TV_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE_MLA; + constexpr uint32_t HM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 1 * UB_UINT8_LINE_SIZE; + constexpr uint32_t GM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 13 * UB_UINT8_LINE_SIZE; + + kvSplitCoreNum = kvSplitCoreNum_; + loUbTensor = resource.ubBuf.template GetBufferByByte(LO_UB_TENSOR_OFFSET); + dmUbTensor = resource.ubBuf.template 
GetBufferByByte(DM_UB_TENSOR_OFFSET); + llUbTensor = resource.ubBuf.template GetBufferByByte(LL_UB_TENSOR_OFFSET); + glUbTensor = resource.ubBuf.template GetBufferByByte(GL_UB_TENSOR_OFFSET); + tvUbTensor = resource.ubBuf.template GetBufferByByte(TV_UB_TENSOR_OFFSET); + goUbTensor32 = resource.ubBuf.template GetBufferByByte(GO_UB_TENSOR_OFFSET); + goUbTensor16 = resource.ubBuf.template GetBufferByByte(GO_UB_TENSOR_OFFSET); + hmUbTensor = resource.ubBuf.template GetBufferByByte(HM_UB_TENSOR_OFFSET); + gmUbTensor = resource.ubBuf.template GetBufferByByte(GM_UB_TENSOR_OFFSET); + } + + CATLASS_DEVICE + ~BlockEpilogue() {} + + CATLASS_DEVICE + void SetMask(int32_t len) + { + uint64_t mask = 0; + uint64_t one = 1; + uint64_t temp = len % FLOAT_VECTOR_SIZE; + for (int64_t i = 0; i < temp; i++) { + mask |= one << i; + } + if (len == VECTOR_SIZE) { + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else if (len >= FLOAT_VECTOR_SIZE) { + AscendC::SetVectorMask(mask, (uint64_t)-1); + } else { + AscendC::SetVectorMask(0x0, mask); + } + } + + CATLASS_DEVICE + void SubCoreCompute(AscendC::GlobalTensor gInput, AscendC::GlobalTensor gUpdate, + AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gOCoreTmp, + AscendC::GlobalTensor gl, const LayoutInput &layoutInput, + const LayoutOutput &layoutOutput, const LayoutUpdate &layoutUpdate, uint32_t nIdx, + uint32_t isLastNTile, uint32_t needRowLoop, uint32_t rowLoopIdx, uint32_t proTokenIdx, + uint32_t proTokenNum, uint32_t epiTokenNum, uint32_t integralHeadNum, + uint32_t rescaleOPingPongFlag, uint32_t &glFlag) + { + uint32_t curRowNum = layoutInput.shape(0); + uint32_t embed = layoutInput.shape(1); + uint32_t embedRound = layoutInput.stride(0); + uint32_t strideQO = layoutOutput.stride(0); + uint32_t tokenNumPerHead = layoutOutput.shape(0); + uint32_t curRowNumAligned64 = (curRowNum + FLOAT_ELENUM_PER_VECCALC - 1) / FLOAT_ELENUM_PER_VECCALC; + uint32_t curRowNumRound = (curRowNum + HALF_ELENUM_PER_BLK - 1) / HALF_ELENUM_PER_BLK * HALF_ELENUM_PER_BLK; + uint64_t dmUbOffsetCurCycle = + (uint64_t)(rescaleOPingPongFlag * HALF_DM_UB_SIZE + rowLoopIdx * ROW_WISE_CYCLE_TILE); + uint64_t llUbOffsetCurCycle = + (uint64_t)(rescaleOPingPongFlag * HALF_LL_UB_SIZE + rowLoopIdx * ROW_WISE_CYCLE_TILE); + uint32_t oUbOffset = oPingPangFlag * ROW_WISE_CYCLE_TILE * embedRound; + AscendC::WaitFlag(oPingPangFlag); + if ((nIdx - 1) != 0) { + AscendC::DataCopy(loUbTensor[oUbOffset], gInput, + AscendC::DataCopyParams(1, curRowNum * embedRound / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + AscendC::WaitFlag(oPingPangFlag + 4); + if ((nIdx - 1) != 0) { + // *** dm = exp(dm) + if (rowLoopIdx == 0) { + AscendC::Exp(dmUbTensor[dmUbOffsetCurCycle], dmUbTensor[dmUbOffsetCurCycle], (uint64_t)0, + curRowNumAligned64, AscendC::UnaryRepeatParams(1, 1, 8, 8)); + AscendC::PipeBarrier(); + AscendC::Mul(glUbTensor, dmUbTensor[dmUbOffsetCurCycle], glUbTensor, (uint64_t)0, + curRowNumAligned64, AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + AscendC::Add(glUbTensor, glUbTensor, llUbTensor[llUbOffsetCurCycle], (uint64_t)0, + curRowNumAligned64, AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + } + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + AscendC::Brcb(tvUbTensor.ReinterpretCast(), + dmUbTensor[dmUbOffsetCurCycle].ReinterpretCast(), curRowNumRound / FLOAT_BLOCK_SIZE, + AscendC::BrcbRepeatParams(1, 8)); + 
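// tvUbTensor now holds the per-row rescale factor exp(m_prev - m_new), broadcast to one 8-float block per row, and is used to rescale the running output go. + 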
AscendC::PipeBarrier(); + if (needRowLoop) { + AscendC::DataCopy(goUbTensor32[oUbOffset], gUpdate, + AscendC::DataCopyParams(1, curRowNum * embedRound / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + // *** go = go * dm_block + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + for (uint32_t mulIdx = 0; mulIdx < embed / FLOAT_VECTOR_SIZE; ++mulIdx) { + AscendC::Mul(goUbTensor32[oUbOffset + mulIdx * FLOAT_VECTOR_SIZE], + goUbTensor32[oUbOffset + mulIdx * FLOAT_VECTOR_SIZE], tvUbTensor, + (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 0, embedRound / FLOAT_BLOCK_SIZE, + embedRound / FLOAT_BLOCK_SIZE, 1)); + } + if (embed % FLOAT_VECTOR_SIZE > 0) { + SetMask(embed % FLOAT_VECTOR_SIZE); + AscendC::Mul(goUbTensor32[oUbOffset + embed / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + goUbTensor32[oUbOffset + embed / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tvUbTensor, (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 0, embedRound / FLOAT_BLOCK_SIZE, + embedRound / FLOAT_BLOCK_SIZE, 1)); + + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + AscendC::PipeBarrier(); + // *** go = lo + go + AscendC::Add(goUbTensor32[oUbOffset], goUbTensor32[oUbOffset], loUbTensor[oUbOffset], + (uint64_t)0, + (curRowNum * embedRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + + AscendC::PipeBarrier(); + } else { + // *** gl = ll + if (rowLoopIdx == 0) { + AscendC::DataCopy(glUbTensor, llUbTensor[llUbOffsetCurCycle], + AscendC::DataCopyParams(1, 64 / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::PipeBarrier(); + } + AscendC::DataCopy(goUbTensor32[oUbOffset], gInput, + AscendC::DataCopyParams(1, curRowNum * embedRound / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + AscendC::SetFlag(oPingPangFlag); + + if (isLastNTile) { + AscendC::Brcb(tvUbTensor.ReinterpretCast(), + glUbTensor.ReinterpretCast()[rowLoopIdx * ROW_WISE_CYCLE_TILE], + curRowNumRound / FLOAT_BLOCK_SIZE, AscendC::BrcbRepeatParams(1, 8)); + AscendC::PipeBarrier(); + // *** go = go / gl_block + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + for (uint32_t divIdx = 0; divIdx < embed / FLOAT_VECTOR_SIZE; ++divIdx) { + AscendC::Div(goUbTensor32[oUbOffset + divIdx * FLOAT_VECTOR_SIZE], + goUbTensor32[oUbOffset + divIdx * FLOAT_VECTOR_SIZE], tvUbTensor, + (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 0, embedRound / FLOAT_BLOCK_SIZE, + embedRound / FLOAT_BLOCK_SIZE, 1)); + } + + if (embed % FLOAT_VECTOR_SIZE > 0) { + SetMask(embed % FLOAT_VECTOR_SIZE); + AscendC::Div(goUbTensor32[oUbOffset + embed / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + goUbTensor32[oUbOffset + embed / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tvUbTensor, (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 0, embedRound / FLOAT_BLOCK_SIZE, + embedRound / FLOAT_BLOCK_SIZE, 1)); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); // fix hidden_size=96 + } + AscendC::PipeBarrier(); + + if (kvSplitCoreNum != 1) { + // log(l) + AscendC::Ln(tvUbTensor, tvUbTensor, (uint64_t)0, curRowNum, + AscendC::UnaryRepeatParams(1, 1, 8, 8)); + AscendC::PipeBarrier(); + AscendC::Brcb(hmUbTensor.ReinterpretCast(), + gmUbTensor.ReinterpretCast()[rowLoopIdx * ROW_WISE_CYCLE_TILE], + curRowNumRound / FLOAT_BLOCK_SIZE, AscendC::BrcbRepeatParams(1, 8)); + AscendC::PipeBarrier(); + // logf(lse_sum) + lse_max + AscendC::Add(tvUbTensor, tvUbTensor, hmUbTensor, (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 1, 
8, 8, 8)); + AscendC::PipeBarrier(); + + AscendC::SetFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID2); + AscendC::DataCopyPad(gl, tvUbTensor, + AscendC::DataCopyExtParams(curRowNum, 4, 0, (kvSplitCoreNum - 1) * 4, 0)); + + if (glFlag == 0) { + AscendC::SetFlag(EVENT_ID2); + glFlag = 1; + } + uint32_t srcGap = ((embed % 16 <= 8) && (embed % 16 > 0)) ? 1 : 0; + AscendC::DataCopyPad( + gOCoreTmp, goUbTensor32[oUbOffset], + AscendC::DataCopyExtParams(curRowNum, embed * 4, srcGap, (kvSplitCoreNum - 1) * embed * 4, 0)); + AscendC::SetFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID3); + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + } else { + // *** go = castfp32to16(go) + if (std::is_same::value) { + AscendC::Cast( + goUbTensor16[oUbOffset * 2], goUbTensor32[oUbOffset], AscendC::RoundMode::CAST_RINT, + (uint64_t)0, (curRowNum * embedRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 4, 8)); + } else { + AscendC::Cast( + goUbTensor16[oUbOffset * 2], goUbTensor32[oUbOffset], AscendC::RoundMode::CAST_NONE, + (uint64_t)0, (curRowNum * embedRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 4, 8)); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + // ********************* move O to GM ************************ + if (tokenNumPerHead == 1) { + AscendC::DataCopyPad(gOutput, goUbTensor16[oUbOffset * 2], + AscendC::DataCopyExtParams(curRowNum, embed * 2, 0, 0, 0)); + } else { + uint32_t innerOGmOffset = 0; + uint32_t inner_go_ubuf_offset = oUbOffset * 2; + if (proTokenNum != 0) { + AscendC::DataCopyPad( + gOutput[innerOGmOffset + proTokenIdx * strideQO], goUbTensor16[inner_go_ubuf_offset], + AscendC::DataCopyExtParams(proTokenNum, embed * 2, 0, (strideQO - embed) * 2, 0)); + innerOGmOffset += embed; + inner_go_ubuf_offset += proTokenNum * embed; + } + for (uint32_t qN_idx = 0; qN_idx < integralHeadNum; qN_idx++) { + AscendC::DataCopyPad( + gOutput[innerOGmOffset], goUbTensor16[inner_go_ubuf_offset], + AscendC::DataCopyExtParams(tokenNumPerHead, embed * 2, 0, (strideQO - embed) * 2, 0)); + innerOGmOffset += embed; + inner_go_ubuf_offset += tokenNumPerHead * embed; + } + if (epiTokenNum != 0) { + AscendC::DataCopyPad( + gOutput[innerOGmOffset], goUbTensor16[inner_go_ubuf_offset], + AscendC::DataCopyExtParams(epiTokenNum, embed * 2, 0, (strideQO - embed) * 2, 0)); + } + } + } + } else if (needRowLoop) { + AscendC::SetFlag(EVENT_ID5); + AscendC::WaitFlag(EVENT_ID5); + AscendC::DataCopy(gUpdate, goUbTensor32[oUbOffset], + AscendC::DataCopyParams(1, curRowNum * embedRound / FLOAT_BLOCK_SIZE, 0, 0)); + } + AscendC::SetFlag(oPingPangFlag + 4); + if (needRowLoop) { + oPingPangFlag = 1 - oPingPangFlag; + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gInput, AscendC::GlobalTensor gUpdate, + AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gOCoreTmp, + AscendC::GlobalTensor gl, const LayoutInput &layoutInput, + const LayoutOutput &layoutOutput, const LayoutUpdate &layoutUpdate, GemmCoord actualBlockShape, + uint32_t nIdx, uint32_t isLastNTile, uint32_t curHeadNum, uint32_t rescaleOPingPongFlag, + uint32_t &glFlag) + { + uint32_t tokenNumPerHead = layoutOutput.shape(0); + uint32_t embed = layoutInput.shape(1); + uint32_t rowActual = actualBlockShape.m(); // curHeadNum * tokenNumPerHead + uint32_t columnActual = actualBlockShape.n(); // embed + + uint32_t subBlockIdx = AscendC::GetSubBlockIdx(); + uint32_t subBlockNum = AscendC::GetSubBlockNum(); + + uint32_t curHeadSplitSubBlock = curHeadNum / 
subBlockNum; + uint32_t curHeadThisSubBlock = (subBlockIdx == 0) ? curHeadSplitSubBlock : (curHeadNum - curHeadSplitSubBlock); + + uint32_t rowActualThisSubBlock = curHeadThisSubBlock * tokenNumPerHead; + uint32_t rowOffsetSubBlock = subBlockIdx * curHeadSplitSubBlock * tokenNumPerHead; + uint32_t outOffsetSubBlock = subBlockIdx * curHeadSplitSubBlock * embed; + + if (rowActualThisSubBlock > 0) { + uint32_t rowLoop = (rowActualThisSubBlock + ROW_WISE_CYCLE_TILE - 1) / ROW_WISE_CYCLE_TILE; + uint32_t needRowLoop = (rowLoop > 1) ? 1 : 0; + // The rows of each cycle consist of multiple heads with several tokens. + // There are several integral heads, one prologue head, one epilogue head. + uint32_t proTokenIdx = 0; // the token idx of the start token of the prologue part + uint32_t proTokenNum = 0; // the token num of the prologue part + uint32_t epiTokenNum = 0; // the token num of the epilogue part + uint32_t integralHeadNum = 0; // the number of integral heads within a cycle + for (uint32_t rowLoopIdx = 0; rowLoopIdx < rowLoop; rowLoopIdx++) { + uint32_t rowOffsetLoop = rowLoopIdx * ROW_WISE_CYCLE_TILE; + uint32_t rowOffsetCurCycle = rowOffsetSubBlock + rowOffsetLoop; + uint32_t rowActualCurCycle = (rowLoopIdx == (rowLoop - 1)) + ? rowActualThisSubBlock - rowLoopIdx * ROW_WISE_CYCLE_TILE + : ROW_WISE_CYCLE_TILE; + int64_t offsetInput = layoutInput.GetOffset(MatrixCoord(rowOffsetCurCycle, 0)); + auto gInputThisCurCycle = gInput[offsetInput]; + auto layoutInputCurCycle = layoutInput.GetTileLayout(MatrixCoord(rowActualCurCycle, columnActual)); + int64_t offsetOutput = rowLoopIdx * ROW_WISE_CYCLE_TILE / tokenNumPerHead * embed + outOffsetSubBlock; + auto gOutputCurCycle = gOutput[offsetOutput]; + auto layoutOutputCurCycle = layoutOutput; + int64_t offsetUpdate = layoutUpdate.GetOffset(MatrixCoord(rowOffsetCurCycle, 0)); + auto gUpdateCurCycle = gUpdate[offsetUpdate]; + auto layoutUpdateCurCycle = layoutUpdate.GetTileLayout(MatrixCoord(rowActualCurCycle, columnActual)); + proTokenIdx = epiTokenNum; + proTokenNum = (tokenNumPerHead - epiTokenNum) % tokenNumPerHead; + integralHeadNum = (rowActualCurCycle - proTokenNum) / tokenNumPerHead; + epiTokenNum = rowActualCurCycle - proTokenNum - integralHeadNum * tokenNumPerHead; + SubCoreCompute(gInputThisCurCycle, gUpdateCurCycle, gOutputCurCycle, + gOCoreTmp[rowOffsetLoop * embed * kvSplitCoreNum], gl[rowOffsetLoop * kvSplitCoreNum], + layoutInputCurCycle, layoutOutputCurCycle, layoutUpdateCurCycle, nIdx, isLastNTile, + needRowLoop, rowLoopIdx, proTokenIdx, proTokenNum, epiTokenNum, integralHeadNum, + rescaleOPingPongFlag, glFlag); + } + } + } + +private: + uint32_t kvSplitCoreNum = 1; + uint32_t oPingPangFlag = 0; + AscendC::LocalTensor goUbTensor16; + AscendC::LocalTensor loUbTensor; + AscendC::LocalTensor dmUbTensor; + AscendC::LocalTensor llUbTensor; + AscendC::LocalTensor glUbTensor; + AscendC::LocalTensor tvUbTensor; + AscendC::LocalTensor goUbTensor32; + AscendC::LocalTensor hmUbTensor; + AscendC::LocalTensor gmUbTensor; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_RESCALE_O_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_softmax.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_softmax.hpp new file mode 100644 index 00000000..18e8da18 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_softmax.hpp @@ -0,0 +1,359 @@ +/* + * 
Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_SOFTMAX_HPP +#define CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_SOFTMAX_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue +{ +public: + // Type aliases + using DispatchPolicy = EpilogueAtlasA2MLASoftmax; + using ArchTag = typename DispatchPolicy::ArchTag; + using ElementOutput = typename OutputType_::Element; + using ElementInput = typename InputType_::Element; + using ElementMask = typename MaskType_::Element; + + using LayoutOutput = typename OutputType_::Layout; + using LayoutInput = typename InputType_::Layout; + using LayoutMask = typename MaskType_::Layout; + + using CopyGmToUbInput = Tile::CopyGm2Ub; + using CopyGmToUbMask = Tile::CopyGm2Ub; + using CopyUbToGmOutput = Tile::CopyUb2Gm; + + static constexpr uint32_t HALF_ELENUM_PER_BLK = 16; + static constexpr uint32_t FLOAT_BLOCK_SIZE = 8; + static constexpr uint32_t HALF_ELENUM_PER_VECCALC = 128; + static constexpr uint32_t FLOAT_VECTOR_SIZE = 64; + static constexpr uint32_t UB_TILE_SIZE = 16384; // 64 * 128 * 2B + static constexpr uint32_t UB_LINE_SIZE = 512; // 128 * 2 * 2B + static constexpr uint32_t HALF_ELENUM_PER_LINE = 256; // 128 * 2 + static constexpr uint32_t FLOAT_ELENUM_PER_LINE = 128; // 128 + static constexpr uint32_t MULTIPLIER = 2; + static constexpr uint32_t HALF_VECTOR_SIZE = 128; + static constexpr uint32_t BLOCK_SIZE = 16; + static constexpr uint32_t T_BLOCK_SIZE = 32 / 2; + static constexpr uint32_t UB_UINT8_LINE_SIZE = 512; + static constexpr uint32_t UB_UINT8_BLOCK_SIZE_MLA = 16384; + static constexpr uint32_t HALF_DM_UB_SIZE = 128; + static constexpr uint32_t VECTOR_SIZE = 128; + static constexpr uint32_t HALF_LL_UB_SIZE = 256; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, half tor_, uint32_t kvSplitCoreNum_) + { + // Allocate UB space + constexpr uint32_t LS_UB_TENSOR_OFFSET = 0; + constexpr uint32_t LP_UB_TENSOR_OFFSET = 2 * UB_UINT8_BLOCK_SIZE_MLA; + constexpr uint32_t LM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA; + constexpr uint32_t HM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 1 * UB_UINT8_LINE_SIZE; + constexpr uint32_t DM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 6 * UB_UINT8_LINE_SIZE; + constexpr uint32_t LL_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 9 * UB_UINT8_LINE_SIZE; + constexpr uint32_t GM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 13 * UB_UINT8_LINE_SIZE; + constexpr uint32_t TV_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE_MLA; + + tor = tor_; + kvSplitCoreNum = kvSplitCoreNum_; + tvUbTensor16 = resource.ubBuf.template GetBufferByByte(LP_UB_TENSOR_OFFSET); + lpUbTensor32 = resource.ubBuf.template 
GetBufferByByte(LP_UB_TENSOR_OFFSET); + lsUbTensor = resource.ubBuf.template GetBufferByByte(LS_UB_TENSOR_OFFSET); + lmUbTensor = resource.ubBuf.template GetBufferByByte(LM_UB_TENSOR_OFFSET); + hmUbTensor = resource.ubBuf.template GetBufferByByte(HM_UB_TENSOR_OFFSET); + gmUbTensor = resource.ubBuf.template GetBufferByByte(GM_UB_TENSOR_OFFSET); + dmUbTensor = resource.ubBuf.template GetBufferByByte(DM_UB_TENSOR_OFFSET); + llUbTensor = resource.ubBuf.template GetBufferByByte(LL_UB_TENSOR_OFFSET); + tvUbTensor = resource.ubBuf.template GetBufferByByte(TV_UB_TENSOR_OFFSET); + } + + CATLASS_DEVICE + ~BlockEpilogue() {} + + CATLASS_DEVICE + void SetMask(int32_t len) + { + uint64_t mask = 0; + uint64_t one = 1; + uint64_t temp = len % FLOAT_VECTOR_SIZE; + for (int64_t i = 0; i < temp; i++) { + mask |= one << i; + } + if (len == VECTOR_SIZE) { + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else if (len >= FLOAT_VECTOR_SIZE) { + AscendC::SetVectorMask(mask, (uint64_t)-1); + } else { + AscendC::SetVectorMask(0x0, mask); + } + } + + CATLASS_DEVICE + void ReduceSumRepeatM(const AscendC::LocalTensor &dst, const AscendC::LocalTensor &src, + uint32_t curRowNum, uint32_t kSeqTile, uint32_t kSeqTileRound) + { + if (kSeqTile <= FLOAT_VECTOR_SIZE) { + SetMask(kSeqTile); + AscendC::RepeatReduceSum(dst, src, curRowNum, 0, 0, 1, 1, kSeqTileRound / FLOAT_BLOCK_SIZE); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else { + for (uint32_t rowsum_idx = 1; rowsum_idx < kSeqTile / FLOAT_VECTOR_SIZE; ++rowsum_idx) { + AscendC::Add( + src, src, src[rowsum_idx * FLOAT_VECTOR_SIZE], (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 1, kSeqTileRound / FLOAT_BLOCK_SIZE, + kSeqTileRound / FLOAT_BLOCK_SIZE, kSeqTileRound / FLOAT_BLOCK_SIZE)); + AscendC::PipeBarrier(); + } + if (kSeqTile % FLOAT_VECTOR_SIZE > 0) { + SetMask(kSeqTile % FLOAT_VECTOR_SIZE); + AscendC::Add( + src, src, src[kSeqTile / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 1, kSeqTileRound / FLOAT_BLOCK_SIZE, + kSeqTileRound / FLOAT_BLOCK_SIZE, kSeqTileRound / FLOAT_BLOCK_SIZE)); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + AscendC::PipeBarrier(); + AscendC::RepeatReduceSum(dst, src, curRowNum, 0, 0, 1, 1, kSeqTileRound / FLOAT_BLOCK_SIZE); + } + } + + CATLASS_DEVICE + void TensorSubValueRepeatM(const AscendC::LocalTensor &dst, const AscendC::LocalTensor &src, + const AscendC::LocalTensor &MaxTensor, + const AscendC::LocalTensor &tempMaxTensor, uint32_t curRowNum, uint32_t subMRound, + uint32_t kSeqTile, uint32_t kSeqTileRound) + { + AscendC::Brcb(tempMaxTensor.ReinterpretCast(), MaxTensor.ReinterpretCast(), + subMRound / FLOAT_BLOCK_SIZE, AscendC::BrcbRepeatParams(1, 8)); + AscendC::PipeBarrier(); + for (uint32_t subIdx = 0; subIdx < kSeqTile / FLOAT_VECTOR_SIZE; ++subIdx) { + AscendC::Sub(dst[subIdx * FLOAT_VECTOR_SIZE], src[subIdx * FLOAT_VECTOR_SIZE], tempMaxTensor, + (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 0, kSeqTileRound / FLOAT_BLOCK_SIZE, + kSeqTileRound / FLOAT_BLOCK_SIZE, 1)); + } + if (kSeqTile % FLOAT_VECTOR_SIZE > 0) { + SetMask(kSeqTile % FLOAT_VECTOR_SIZE); + AscendC::Sub(dst[kSeqTile / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + src[kSeqTile / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], tempMaxTensor, + (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 0, kSeqTileRound / FLOAT_BLOCK_SIZE, + kSeqTileRound / FLOAT_BLOCK_SIZE, 1)); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + AscendC::PipeBarrier(); + } + + 
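// Row-wise max over kSeqTile columns: tiles wider than 64 floats are folded 64 lanes at a time with Max, then finished with WholeReduceMax. + 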
CATLASS_DEVICE + void ReduceMaxRepeatM(const AscendC::LocalTensor &dst, const AscendC::LocalTensor &src, + const AscendC::LocalTensor &tempTensor, uint32_t curRowNum, uint32_t kSeqTile, + uint32_t kSeqTileRound) + { + if (kSeqTile <= FLOAT_VECTOR_SIZE) { + SetMask(kSeqTile); + AscendC::WholeReduceMax(dst, src, (int32_t)0, curRowNum, 1, 1, + kSeqTileRound / FLOAT_BLOCK_SIZE, + AscendC::ReduceOrder::ORDER_ONLY_VALUE); + } else { + AscendC::DataCopy(tempTensor, src, + AscendC::DataCopyParams(curRowNum, HALF_VECTOR_SIZE / BLOCK_SIZE, + (kSeqTileRound - FLOAT_VECTOR_SIZE) / FLOAT_BLOCK_SIZE, 0)); + AscendC::PipeBarrier(); + for (uint32_t rowmaxIdx = 1; rowmaxIdx < kSeqTile / FLOAT_VECTOR_SIZE; ++rowmaxIdx) { + AscendC::Max( + tempTensor, tempTensor, src[rowmaxIdx * FLOAT_VECTOR_SIZE], (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, kSeqTileRound / FLOAT_BLOCK_SIZE)); + AscendC::PipeBarrier(); + } + if (kSeqTile % FLOAT_VECTOR_SIZE > 0) { + SetMask(kSeqTile % FLOAT_VECTOR_SIZE); + AscendC::Max( + tempTensor, tempTensor, src[kSeqTile / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], (uint64_t)0, + curRowNum, AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, kSeqTileRound / FLOAT_BLOCK_SIZE)); + } + AscendC::PipeBarrier(); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + AscendC::WholeReduceMax(dst, tempTensor, (int32_t)0, curRowNum, 1, 1, 8, + AscendC::ReduceOrder::ORDER_ONLY_VALUE); + } + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void SubCoreCompute(AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gInput, + const LayoutOutput &layoutOutput, const LayoutInput &layoutInput, uint32_t nIdx, + uint32_t softmaxPingPongFlag, uint32_t &glFlag) + { + uint32_t curRowNum = layoutInput.shape(0); + uint32_t kSeqTile = layoutInput.shape(1); + uint32_t kSeqTileRound = layoutInput.stride(0); + uint32_t subMRound = (curRowNum + 16 - 1) / 16 * 16; + uint32_t sub_m_d64 = (curRowNum + 63) / 64; // up aligned to 128 + uint64_t dmUbOffsetCurCycle = (uint64_t)(softmaxPingPongFlag * HALF_DM_UB_SIZE); + uint64_t llUbOffsetCurCycle = (uint64_t)(softmaxPingPongFlag * HALF_LL_UB_SIZE); + AscendC::WaitFlag(EVENT_ID2); + AscendC::DataCopy(lsUbTensor, gInput, + AscendC::DataCopyParams(1, curRowNum * kSeqTileRound / FLOAT_BLOCK_SIZE, 0, 0)); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + // muls scale_value + for (uint32_t mulsIdx = 0; mulsIdx < kSeqTile / FLOAT_VECTOR_SIZE; ++mulsIdx) { + AscendC::Muls( + lsUbTensor[mulsIdx * FLOAT_VECTOR_SIZE], lsUbTensor[mulsIdx * FLOAT_VECTOR_SIZE], tor, (uint64_t)0, + curRowNum, + AscendC::UnaryRepeatParams(1, 1, kSeqTileRound / FLOAT_BLOCK_SIZE, kSeqTileRound / FLOAT_BLOCK_SIZE)); + } + if (kSeqTile % FLOAT_VECTOR_SIZE > 0) { + SetMask(kSeqTile % FLOAT_VECTOR_SIZE); + AscendC::Muls( + lsUbTensor[kSeqTile / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + lsUbTensor[kSeqTile / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], tor, (uint64_t)0, curRowNum, + AscendC::UnaryRepeatParams(1, 1, kSeqTileRound / FLOAT_BLOCK_SIZE, kSeqTileRound / FLOAT_BLOCK_SIZE)); + + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + AscendC::PipeBarrier(); + + // *** lm = rowmax(ls) + ReduceMaxRepeatM(lmUbTensor, lsUbTensor, lpUbTensor32, curRowNum, kSeqTile, kSeqTileRound); + + if (nIdx != 0) { + AscendC::Max(hmUbTensor, lmUbTensor, gmUbTensor, (uint64_t)0, sub_m_d64, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + AscendC::Sub(dmUbTensor[dmUbOffsetCurCycle], gmUbTensor, hmUbTensor, 
(uint64_t)0, sub_m_d64, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + } else { + AscendC::DataCopy(hmUbTensor, lmUbTensor, AscendC::DataCopyParams(1, subMRound / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::PipeBarrier(); + } + // *** gm = hm + AscendC::DataCopy(gmUbTensor, hmUbTensor, AscendC::DataCopyParams(1, subMRound / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::PipeBarrier(); + + if (kvSplitCoreNum != 1) { + if (nIdx == 0) { + if (glFlag == 1) { + AscendC::WaitFlag(EVENT_ID2); + glFlag = 0; + } + } + } + + // *** ls = ls - hm_block + TensorSubValueRepeatM(lsUbTensor, lsUbTensor, hmUbTensor, tvUbTensor, curRowNum, subMRound, kSeqTile, + kSeqTileRound); + + AscendC::Exp(lsUbTensor, lsUbTensor, (uint64_t)0, + (curRowNum * kSeqTileRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 8, 8)); + + AscendC::PipeBarrier(); + // *** lp = castfp32to16(ls) + if (std::is_same::value) { + AscendC::Cast( + tvUbTensor16, lsUbTensor, AscendC::RoundMode::CAST_RINT, (uint64_t)0, + (curRowNum * kSeqTileRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 4, 8)); + } else { + AscendC::Cast( + tvUbTensor16, lsUbTensor, AscendC::RoundMode::CAST_NONE, (uint64_t)0, + (curRowNum * kSeqTileRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 4, 8)); + } + + AscendC::PipeBarrier(); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + uint16_t blockCount = 1; + uint16_t blockLen = curRowNum * kSeqTileRound / T_BLOCK_SIZE; + uint16_t srcStride = 0; + uint16_t dstStride = 0; + + AscendC::DataCopy(gOutput, tvUbTensor16, + AscendC::DataCopyParams(blockCount, // blockCount + blockLen, // blockLen + srcStride, // srcGap + dstStride)); + + // *** ll = rowsum(ls32) + ReduceSumRepeatM(llUbTensor[llUbOffsetCurCycle], lsUbTensor, curRowNum, kSeqTile, kSeqTileRound); + AscendC::SetFlag(EVENT_ID2); + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gInput, + const LayoutOutput &layoutOutput, const LayoutInput &layoutInput, GemmCoord actualBlockShape, + uint32_t nIdx, uint32_t curHeadNum, uint32_t softmaxPingPongFlag, uint32_t &glFlag) + { + uint32_t rowActual = actualBlockShape.m(); + uint32_t nActual = actualBlockShape.n(); + uint32_t tokenNumPerHead = rowActual / curHeadNum; + + uint32_t subBlockIdx = AscendC::GetSubBlockIdx(); + uint32_t subBlockNum = AscendC::GetSubBlockNum(); + + uint32_t curHeadSplitSubBlock = curHeadNum / subBlockNum; + uint32_t curHeadThisSubBlock = (subBlockIdx == 0) ? 
curHeadSplitSubBlock : (curHeadNum - curHeadSplitSubBlock); + + uint32_t rowActualThisSubBlock = curHeadThisSubBlock * tokenNumPerHead; + uint32_t rowOffsetSubBlock = subBlockIdx * curHeadSplitSubBlock * tokenNumPerHead; + + if (rowActualThisSubBlock > 0) { + int64_t offsetInput = layoutInput.GetOffset(MatrixCoord(rowOffsetSubBlock, 0)); + auto gInputThisSubBlock = gInput[offsetInput]; + auto layoutInputThisSubBlock = layoutInput.GetTileLayout(MatrixCoord(rowActualThisSubBlock, nActual)); + int64_t offsetOutput = layoutOutput.GetOffset(MatrixCoord(rowOffsetSubBlock, 0)); + auto gOutputThisSubBlock = gOutput[offsetOutput]; + auto layoutOutputThisSubBlock = layoutOutput.GetTileLayout(MatrixCoord(rowActualThisSubBlock, nActual)); + SubCoreCompute(gOutputThisSubBlock, gInputThisSubBlock, layoutOutputThisSubBlock, layoutInputThisSubBlock, + nIdx, softmaxPingPongFlag, glFlag); + } + } + +private: + float tor; + uint32_t pingpongFlag = 0; + uint32_t kvSplitCoreNum = 1; + AscendC::LocalTensor tvUbTensor16; + AscendC::LocalTensor lpUbTensor32; + AscendC::LocalTensor lsUbTensor; + AscendC::LocalTensor lmUbTensor; + AscendC::LocalTensor hmUbTensor; + AscendC::LocalTensor gmUbTensor; + AscendC::LocalTensor dmUbTensor; + AscendC::LocalTensor llUbTensor; + AscendC::LocalTensor tvUbTensor; + + CopyGmToUbInput copyGmToUbInput; + CopyGmToUbMask copyGmToUbMask; + CopyUbToGmOutput copyUbToGmOutput; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_SOFTMAX_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_tp1_rescale_o.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_tp1_rescale_o.hpp new file mode 100644 index 00000000..6e2e03bc --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_tp1_rescale_o.hpp @@ -0,0 +1,324 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_TP1_RESCALE_O_HPP +#define CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_TP1_RESCALE_O_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue +{ +public: + // Type aliases + using DispatchPolicy = EpilogueAtlasA2MLATP1RescaleO; + using ArchTag = typename DispatchPolicy::ArchTag; + + using ElementOutput = typename OutputType_::Element; + using ElementUpdate = typename UpdateType_::Element; + using ElementInput = typename InputType_::Element; + + using LayoutOutput = typename OutputType_::Layout; + using LayoutUpdate = typename UpdateType_::Layout; + using LayoutInput = typename InputType_::Layout; + + static constexpr uint32_t HALF_ELENUM_PER_BLK = 16; + static constexpr uint32_t HALF_ELENUM_PER_VECCALC = 128; + static constexpr uint32_t FLOAT_ELENUM_PER_VECCALC = 64; + static constexpr uint32_t HALF_ELENUM_PER_LINE = 256; + static constexpr uint32_t FLOAT_ELENUM_PER_LINE = 128; + static constexpr uint32_t MULTIPLIER = 2; + static constexpr uint32_t FLOAT_BLOCK_SIZE = 8; + static constexpr uint32_t FLOAT_VECTOR_SIZE = 64; + static constexpr uint32_t UB_UINT8_LINE_SIZE = 512; + static constexpr uint32_t UB_UINT8_BLOCK_SIZE_MLA = 16384; + static constexpr uint32_t ROW_WISE_CYCLE_TILE = 8; + static constexpr uint32_t HALF_DM_UB_SIZE = 64; + static constexpr uint32_t HALF_LL_UB_SIZE = 256; + static constexpr uint32_t VECTOR_SIZE = 128; + static constexpr uint32_t NUM4 = 4; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, uint32_t kvSplitCoreNum_ = 1) + { + // Allocate UB space + constexpr uint32_t LO_UB_TENSOR_OFFSET = 4 * UB_UINT8_BLOCK_SIZE_MLA; + constexpr uint32_t DM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 6 * UB_UINT8_LINE_SIZE; + constexpr uint32_t GL_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 16 * UB_UINT8_LINE_SIZE; + constexpr uint32_t GO_UB_TENSOR_OFFSET = 8 * UB_UINT8_BLOCK_SIZE_MLA; + constexpr uint32_t TV_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE_MLA; + constexpr uint32_t HM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 1 * UB_UINT8_LINE_SIZE; + constexpr uint32_t GM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 14 * UB_UINT8_LINE_SIZE; + + kvSplitCoreNum = kvSplitCoreNum_; + loUbTensor = resource.ubBuf.template GetBufferByByte(LO_UB_TENSOR_OFFSET); + dmUbTensor = resource.ubBuf.template GetBufferByByte(DM_UB_TENSOR_OFFSET); + glUbTensor = resource.ubBuf.template GetBufferByByte(GL_UB_TENSOR_OFFSET); + tvUbTensor = resource.ubBuf.template GetBufferByByte(TV_UB_TENSOR_OFFSET); + goUbTensor16 = resource.ubBuf.template GetBufferByByte(GO_UB_TENSOR_OFFSET); + goUbTensor32 = resource.ubBuf.template GetBufferByByte(GO_UB_TENSOR_OFFSET); + hmUbTensor = resource.ubBuf.template GetBufferByByte(HM_UB_TENSOR_OFFSET); + gmUbTensor = resource.ubBuf.template GetBufferByByte(GM_UB_TENSOR_OFFSET); + } + + CATLASS_DEVICE + ~BlockEpilogue() {} + + CATLASS_DEVICE + void SetMask(int32_t len) + { + uint64_t mask = 0; + uint64_t one = 1; + uint64_t temp = len % FLOAT_VECTOR_SIZE; + for (int64_t i = 0; i < temp; i++) { + mask |= one << i; + } + + if (len == VECTOR_SIZE) { + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else if (len >= FLOAT_VECTOR_SIZE) { + AscendC::SetVectorMask(mask, (uint64_t)-1); + } else { + AscendC::SetVectorMask(0x0, mask); + 
} + } + + CATLASS_DEVICE + void SetkvSplitCoreNum(uint32_t kvSplitCoreNum_) + { + kvSplitCoreNum = kvSplitCoreNum_; + } + + CATLASS_DEVICE + void SubCoreCompute(AscendC::GlobalTensor gInput, AscendC::GlobalTensor gUpdate, + AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gOCoreTmp, + AscendC::GlobalTensor gl, const LayoutInput &layoutInput, + const LayoutOutput &layoutOutput, const LayoutUpdate &layoutUpdate, uint32_t nIdx, + uint32_t isLastNTile, uint32_t needRowLoop, uint32_t rowLoopIdx, uint32_t rescaleOPingPongFlag, + uint32_t &glFlag) + { + uint32_t curRowNum = layoutInput.shape(0); + uint32_t embed = layoutInput.shape(1); + uint32_t embedRound = layoutInput.stride(0); + uint32_t curRowNumRound = RoundUp(curRowNum); + uint64_t dmUbOffsetCurCycle = + (uint64_t)(rescaleOPingPongFlag * HALF_DM_UB_SIZE + rowLoopIdx * ROW_WISE_CYCLE_TILE); + uint32_t oUbOffset = oPingPangFlag * ROW_WISE_CYCLE_TILE * embedRound; + AscendC::WaitFlag(oPingPangFlag); + if (nIdx != NUM4) { + AscendC::DataCopy(loUbTensor[oUbOffset], gInput, + AscendC::DataCopyParams(1, curRowNum * embedRound / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + AscendC::WaitFlag(oPingPangFlag + 4); + if (nIdx != NUM4) { + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + AscendC::Brcb(tvUbTensor.ReinterpretCast(), + dmUbTensor[dmUbOffsetCurCycle].ReinterpretCast(), curRowNumRound / FLOAT_BLOCK_SIZE, + AscendC::BrcbRepeatParams(1, 8)); + AscendC::PipeBarrier(); + if (needRowLoop) { + AscendC::DataCopy(goUbTensor32[oUbOffset], gUpdate, + AscendC::DataCopyParams(1, curRowNum * embedRound / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + // *** go = go * dm_block + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + for (uint32_t vmul_idx = 0; vmul_idx < embed / FLOAT_VECTOR_SIZE; ++vmul_idx) { + AscendC::Mul(goUbTensor32[oUbOffset + vmul_idx * FLOAT_VECTOR_SIZE], + goUbTensor32[oUbOffset + vmul_idx * FLOAT_VECTOR_SIZE], tvUbTensor, + (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 0, embedRound / FLOAT_BLOCK_SIZE, + embedRound / FLOAT_BLOCK_SIZE, 1)); + } + if (embed % FLOAT_VECTOR_SIZE > 0) { + SetMask(embed % FLOAT_VECTOR_SIZE); + AscendC::Mul(goUbTensor32[oUbOffset + embed / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + goUbTensor32[oUbOffset + embed / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tvUbTensor, (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 0, embedRound / FLOAT_BLOCK_SIZE, + embedRound / FLOAT_BLOCK_SIZE, 1)); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + AscendC::PipeBarrier(); + // *** go = lo + go + AscendC::Add(goUbTensor32[oUbOffset], goUbTensor32[oUbOffset], loUbTensor[oUbOffset], + (uint64_t)0, + (curRowNum * embedRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + } else { + // *** gl = ll + AscendC::DataCopy(goUbTensor32[oUbOffset], gInput, + AscendC::DataCopyParams(1, curRowNum * embedRound / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + AscendC::SetFlag(oPingPangFlag); + + if (isLastNTile) { + // *** gl_block = expand_to_block(gl), stored in tv + AscendC::Brcb(tvUbTensor.ReinterpretCast(), + glUbTensor.ReinterpretCast()[rowLoopIdx * ROW_WISE_CYCLE_TILE], + curRowNumRound / FLOAT_BLOCK_SIZE, AscendC::BrcbRepeatParams(1, 8)); + AscendC::PipeBarrier(); + // *** go = go / gl_block + 
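// gl accumulates the row-wise softmax denominator (sum of exp, rescaled across n-tiles); dividing go by its per-row broadcast yields the normalized output. + 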
AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + for (uint32_t vdiv_idx = 0; vdiv_idx < embed / FLOAT_VECTOR_SIZE; ++vdiv_idx) { + AscendC::Div(goUbTensor32[oUbOffset + vdiv_idx * FLOAT_VECTOR_SIZE], + goUbTensor32[oUbOffset + vdiv_idx * FLOAT_VECTOR_SIZE], tvUbTensor, + (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 0, embedRound / FLOAT_BLOCK_SIZE, + embedRound / FLOAT_BLOCK_SIZE, 1)); + } + if (embed % FLOAT_VECTOR_SIZE > 0) { + SetMask(embed % FLOAT_VECTOR_SIZE); + AscendC::Div(goUbTensor32[oUbOffset + embed / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + goUbTensor32[oUbOffset + embed / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tvUbTensor, (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 0, embedRound / FLOAT_BLOCK_SIZE, + embedRound / FLOAT_BLOCK_SIZE, 1)); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + AscendC::PipeBarrier(); + + if (kvSplitCoreNum != 1) { + // log(l) + AscendC::Ln(tvUbTensor, tvUbTensor, (uint64_t)0, curRowNum, + AscendC::UnaryRepeatParams(1, 1, 8, 8)); + AscendC::PipeBarrier(); + AscendC::Brcb(hmUbTensor.ReinterpretCast(), + gmUbTensor.ReinterpretCast()[rowLoopIdx * ROW_WISE_CYCLE_TILE], + curRowNumRound / FLOAT_BLOCK_SIZE, AscendC::BrcbRepeatParams(1, 8)); + AscendC::PipeBarrier(); + // logf(lse_sum) + lse_max + AscendC::Add(tvUbTensor, tvUbTensor, hmUbTensor, (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + + AscendC::SetFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID2); + AscendC::DataCopyPad(gl, tvUbTensor, + AscendC::DataCopyExtParams(curRowNum, 4, 0, (kvSplitCoreNum - 1) * 4, 0)); + + if (glFlag == 0) { + AscendC::SetFlag(EVENT_ID2); + glFlag = 1; + } + uint32_t srcGap = ((embed % 16 <= 8) && (embed % 16 > 0)) ? 1 : 0; + AscendC::DataCopyPad( + gOCoreTmp, goUbTensor32[oUbOffset], + AscendC::DataCopyExtParams(curRowNum, embed * 4, srcGap, (kvSplitCoreNum - 1) * embed * 4, 0)); + AscendC::SetFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID3); + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + } else { + // *** go = castfp32to16(go) + if (std::is_same::value) { + AscendC::Cast( + goUbTensor16[oUbOffset * 2], goUbTensor32[oUbOffset], AscendC::RoundMode::CAST_RINT, + (uint64_t)0, (curRowNum * embedRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 4, 8)); + } else { + AscendC::Cast( + goUbTensor16[oUbOffset * 2], goUbTensor32[oUbOffset], AscendC::RoundMode::CAST_NONE, + (uint64_t)0, (curRowNum * embedRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 4, 8)); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + // ********************* move O to GM ************************ + AscendC::DataCopyPad(gOutput, goUbTensor16[oUbOffset * 2], + AscendC::DataCopyExtParams(curRowNum, embed * 2, 0, 0, 0)); + } + } else if (needRowLoop) { + AscendC::SetFlag(EVENT_ID5); + AscendC::WaitFlag(EVENT_ID5); + AscendC::DataCopy(gUpdate, goUbTensor32[oUbOffset], + AscendC::DataCopyParams(1, curRowNum * embedRound / FLOAT_BLOCK_SIZE, 0, 0)); + } + AscendC::SetFlag(oPingPangFlag + 4); + oPingPangFlag = 1 - oPingPangFlag; + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gInput, AscendC::GlobalTensor gUpdate, + AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gOCoreTmp, + AscendC::GlobalTensor gl, const LayoutInput &layoutInput, + const LayoutUpdate &layoutUpdate, const LayoutOutput &layoutOutput, GemmCoord actualBlockShape, + uint32_t nIdx, uint32_t isLastNTile, uint32_t 
rescaleOPingPongFlag, uint32_t &glFlag) + { + uint32_t embed = layoutInput.shape(1); + uint32_t rowActual = actualBlockShape.m(); + uint32_t columnActual = actualBlockShape.n(); + + uint32_t subBlockIdx = AscendC::GetSubBlockIdx(); + uint32_t subBlockNum = AscendC::GetSubBlockNum(); + + uint32_t curRowSplitSubBlock = rowActual / subBlockNum; + uint32_t rowActualThisSubBlock = (subBlockIdx == 0) ? curRowSplitSubBlock : (rowActual - curRowSplitSubBlock); + uint32_t rowOffsetSubBlock = subBlockIdx * curRowSplitSubBlock; + + if (rowActualThisSubBlock > 0) { + uint32_t rowLoop = (rowActualThisSubBlock + ROW_WISE_CYCLE_TILE - 1) / ROW_WISE_CYCLE_TILE; + uint32_t needRowLoop = (rowLoop > 1) ? 1 : 0; + for (uint32_t rowLoopIdx = 0; rowLoopIdx < rowLoop; rowLoopIdx++) { + uint32_t rowOffsetLoop = rowLoopIdx * ROW_WISE_CYCLE_TILE; + uint32_t rowOffsetCurCycle = rowOffsetSubBlock + rowOffsetLoop; + uint32_t rowActualCurCycle = + (rowLoopIdx == (rowLoop - 1)) ? rowActualThisSubBlock - rowOffsetLoop : ROW_WISE_CYCLE_TILE; + int64_t offsetInput = layoutInput.GetOffset(MatrixCoord(rowOffsetCurCycle, 0)); + auto gInputThisCurCycle = gInput[offsetInput]; + auto layoutInputCurCycle = layoutInput.GetTileLayout(MatrixCoord(rowActualCurCycle, columnActual)); + + int64_t offsetUpdate = layoutUpdate.GetOffset(MatrixCoord(rowOffsetCurCycle, 0)); + auto gUpdateCurCycle = gUpdate[offsetUpdate]; + auto layoutUpdateCurCycle = layoutUpdate.GetTileLayout(MatrixCoord(rowActualCurCycle, columnActual)); + + int64_t offsetOutput = layoutOutput.GetOffset(MatrixCoord(rowOffsetCurCycle, 0)); + auto gOutputCurCycle = gOutput[offsetOutput]; + auto layoutOutputCurCycle = layoutOutput.GetTileLayout(MatrixCoord(rowActualCurCycle, columnActual)); + + SubCoreCompute(gInputThisCurCycle, gUpdateCurCycle, gOutputCurCycle, + gOCoreTmp[rowOffsetLoop * embed * kvSplitCoreNum], gl[rowOffsetLoop * kvSplitCoreNum], + layoutInputCurCycle, layoutOutputCurCycle, layoutUpdateCurCycle, nIdx, isLastNTile, + needRowLoop, rowLoopIdx, rescaleOPingPongFlag, glFlag); + } + } + } + +private: + uint32_t kvSplitCoreNum = 1; + uint32_t oPingPangFlag = 0; + AscendC::LocalTensor loUbTensor; + AscendC::LocalTensor dmUbTensor; + AscendC::LocalTensor glUbTensor; + AscendC::LocalTensor tvUbTensor; + AscendC::LocalTensor goUbTensor16; + AscendC::LocalTensor goUbTensor32; + AscendC::LocalTensor hmUbTensor; + AscendC::LocalTensor gmUbTensor; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_TP1_RESCALE_O_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_tp1_softmax.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_tp1_softmax.hpp new file mode 100644 index 00000000..fd70350d --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_mla_tp1_softmax.hpp @@ -0,0 +1,501 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_TP1_SOFTMAX_HPP +#define CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_TP1_SOFTMAX_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue +{ +public: + // Type aliases + using DispatchPolicy = EpilogueAtlasA2MLATP1Softmax; + using ArchTag = typename DispatchPolicy::ArchTag; + using ElementOutput = typename OutputType_::Element; + using ElementInput = typename InputType_::Element; + using ElementMask = typename MaskType_::Element; + + using LayoutOutput = typename OutputType_::Layout; + using LayoutInput = typename InputType_::Layout; + using LayoutMask = typename MaskType_::Layout; + + using CopyGmToUbInput = Tile::CopyGm2Ub; + using CopyGmToUbMask = Tile::CopyGm2Ub; + using CopyUbToGmOutput = Tile::CopyUb2Gm; + + static constexpr uint32_t FLOAT_BLOCK_SIZE = 8; + static constexpr uint32_t FLOAT_VECTOR_SIZE = 64; + static constexpr uint32_t HALF_VECTOR_SIZE = 128; + static constexpr uint32_t BLOCK_SIZE = 16; + static constexpr uint32_t UB_UINT8_LINE_SIZE = 512; + static constexpr uint32_t UB_UINT8_BLOCK_SIZE_MLA = 16384; + static constexpr uint32_t VECTOR_SIZE = 128; + + static constexpr uint32_t REDUCE_UB_SIZE = 1024; + static constexpr uint32_t ROW_OPS_SPEC_MASK_32 = 32; + static constexpr uint32_t ROW_OPS_SPEC_MASK_4 = 4; + static constexpr uint32_t S_BLOCK_STACK = 4; + static constexpr int64_t UB_FLOAT_LINE_SIZE = 64; + static constexpr uint32_t M_SLICE = 16; + static constexpr uint32_t QK_READY_ID = 1; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, half tor_, uint32_t kvSplitCoreNum_ = 1) + { + // Allocate UB space + constexpr uint32_t LS_UB_TENSOR_OFFSET = 0; + constexpr uint32_t LP_UB_TENSOR_OFFSET = 0; + constexpr uint32_t LM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA; + constexpr uint32_t HM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 1 * UB_UINT8_LINE_SIZE; + constexpr uint32_t DM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 6 * UB_UINT8_LINE_SIZE; + constexpr uint32_t LL_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 10 * UB_UINT8_LINE_SIZE; + constexpr uint32_t GM_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 14 * UB_UINT8_LINE_SIZE; + constexpr uint32_t GL_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE_MLA + 16 * UB_UINT8_LINE_SIZE; + constexpr uint32_t TV_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE_MLA; + + tor = tor_; + kvSplitCoreNum = kvSplitCoreNum_; + lsUbTensor = resource.ubBuf.template GetBufferByByte(LS_UB_TENSOR_OFFSET); + lpUbTensor = resource.ubBuf.template GetBufferByByte(LP_UB_TENSOR_OFFSET); + lmUbTensor = resource.ubBuf.template GetBufferByByte(LM_UB_TENSOR_OFFSET); + hmUbTensor = resource.ubBuf.template GetBufferByByte(HM_UB_TENSOR_OFFSET); + gmUbTensor = resource.ubBuf.template GetBufferByByte(GM_UB_TENSOR_OFFSET); + dmUbTensor = resource.ubBuf.template GetBufferByByte(DM_UB_TENSOR_OFFSET); + llUbTensor = resource.ubBuf.template GetBufferByByte(LL_UB_TENSOR_OFFSET); + tvUbTensor = resource.ubBuf.template GetBufferByByte(TV_UB_TENSOR_OFFSET); + glUbTensor = resource.ubBuf.template GetBufferByByte(GL_UB_TENSOR_OFFSET); + } + + CATLASS_DEVICE + ~BlockEpilogue() {} + + CATLASS_DEVICE + void 
SetVecMask(int32_t len) + { + uint64_t mask = 0; + uint64_t one = 1; + uint64_t temp = len % FLOAT_VECTOR_SIZE; + for (int64_t i = 0; i < temp; i++) { + mask |= one << i; + } + + if (len == VECTOR_SIZE || len == 0) { + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else if (len >= FLOAT_VECTOR_SIZE) { + AscendC::SetVectorMask(mask, (uint64_t)-1); + } else { + AscendC::SetVectorMask(0x0, mask); + } + } + + CATLASS_DEVICE + void SetkvSplitCoreNum(uint32_t kvSplitCoreNum_) + { + kvSplitCoreNum = kvSplitCoreNum_; + } + + CATLASS_DEVICE + void SetBlockReduceMask(int32_t len) + { + if (len > 8 || len < 1) { + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + return; + } + uint64_t subMask = ((uint64_t)1 << len) - 1; + uint64_t maskValue = (subMask << 48) + (subMask << 32) + (subMask << 16) + subMask + (subMask << 56) + + (subMask << 40) + (subMask << 24) + (subMask << 8); + AscendC::SetVectorMask(maskValue, maskValue); + } + + CATLASS_DEVICE + void RowsumSPECTILE512(const AscendC::LocalTensor &srcUb, const AscendC::LocalTensor &rowsumUb, + const AscendC::LocalTensor &tvUbTensor, uint32_t numRowsRound, uint32_t numElems, + uint32_t numElemsAligned) + { + AscendC::BlockReduceSum(tvUbTensor, srcUb, numRowsRound * numElemsAligned / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + + AscendC::BlockReduceSum(tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, + numRowsRound * numElemsAligned / FLOAT_BLOCK_SIZE / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + AscendC::BlockReduceSum(rowsumUb, tvUbTensor[REDUCE_UB_SIZE], + numRowsRound * numElemsAligned / FLOAT_VECTOR_SIZE / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void RowsumSPECTILE256(const AscendC::LocalTensor &srcUb, const AscendC::LocalTensor &rowsumUb, + const AscendC::LocalTensor &tvUbTensor, uint32_t numRowsRound, uint32_t numElems, + uint32_t numElemsAligned) + { + AscendC::BlockReduceSum(tvUbTensor, srcUb, numRowsRound * numElemsAligned / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + SetVecMask(ROW_OPS_SPEC_MASK_32); + AscendC::BlockReduceSum(tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, numRowsRound, 0, 1, 1, 4); + AscendC::PipeBarrier(); + SetBlockReduceMask(ROW_OPS_SPEC_MASK_4); + AscendC::BlockReduceSum( + rowsumUb, tvUbTensor[REDUCE_UB_SIZE], + (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, 1, 1, 8); + AscendC::PipeBarrier(); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + + CATLASS_DEVICE + void RowsumTAILTILE(const AscendC::LocalTensor &srcUb, const AscendC::LocalTensor &rowsumUb, + const AscendC::LocalTensor &tvUbTensor, uint32_t numRowsRound, uint32_t numElems, + uint32_t numElemsAligned) + { + if (numElems >= FLOAT_VECTOR_SIZE) { + AscendC::BlockReduceSum(tvUbTensor, srcUb, numRowsRound, 0, 1, 1, + numElemsAligned / FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + AscendC::BlockReduceSum( + rowsumUb, tvUbTensor, (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + for (uint64_t rowSumIdx = 1; rowSumIdx < (uint64_t)numElems / FLOAT_VECTOR_SIZE; ++rowSumIdx) { + AscendC::BlockReduceSum(tvUbTensor, srcUb[rowSumIdx * FLOAT_VECTOR_SIZE], numRowsRound, 0, + 1, 1, numElemsAligned / FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + AscendC::BlockReduceSum( + tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, + (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, 1, 1, 8); + AscendC::PipeBarrier(); + SetVecMask(numRowsRound); + 
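SetVecMask and SetBlockReduceMask defined above build the 64-bit lane masks consumed by the vector instructions. The sketch below reproduces only the bit arithmetic on the host so the intent is easier to verify; it is a plain C++ illustration with hypothetical function names, assuming 64 float lanes per repeat and 8 floats per block as in the constants above.

// Host-side reproduction of the lane-mask construction (editorial sketch).
#include <cstdint>
#include <utility>

// Mask covering the low `len` float lanes of a 64-lane repeat, returned in the
// same order the code above passes its two arguments to SetVectorMask.
std::pair<uint64_t, uint64_t> TailLaneMask(int32_t len)
{
    uint64_t mask = 0;
    for (int64_t i = 0; i < len % 64; ++i) {
        mask |= uint64_t{1} << i;
    }
    if (len == 128 || len == 0) return {~uint64_t{0}, ~uint64_t{0}};  // full mask
    if (len >= 64)              return {mask, ~uint64_t{0}};          // partial upper half
    return {uint64_t{0}, mask};                                       // partial lower half
}

// Mask keeping the low `len` floats (1..8) of every 8-float block, replicated
// across all eight blocks of a repeat, mirroring SetBlockReduceMask.
uint64_t BlockReduceLaneMask(int32_t len)
{
    if (len > 8 || len < 1) return ~uint64_t{0};
    const uint64_t sub = (uint64_t{1} << len) - 1;
    uint64_t mask = 0;
    for (int shift = 0; shift < 64; shift += 8) {
        mask |= sub << shift;
    }
    return mask;
}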
AscendC::Add(rowsumUb, rowsumUb, tvUbTensor[REDUCE_UB_SIZE], (uint64_t)0, 1, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } + if (numElems % FLOAT_VECTOR_SIZE > 0) { + SetVecMask(numElems % FLOAT_VECTOR_SIZE); + AscendC::BlockReduceSum(tvUbTensor, srcUb[numElems / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + numRowsRound, 0, 1, 1, numElemsAligned / FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + SetBlockReduceMask((numElems % FLOAT_VECTOR_SIZE + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE); + if (numElems < FLOAT_VECTOR_SIZE) { + AscendC::BlockReduceSum( + rowsumUb, tvUbTensor, (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + 0, 1, 1, 8); + AscendC::PipeBarrier(); + } else { + AscendC::BlockReduceSum( + tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, + (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, 1, 1, 8); + AscendC::PipeBarrier(); + SetVecMask(numRowsRound); + AscendC::Add(rowsumUb, rowsumUb, tvUbTensor[REDUCE_UB_SIZE], (uint64_t)0, 1, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + } + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } + + CATLASS_DEVICE + void RowmaxSPECTILE512(const AscendC::LocalTensor &srcUb, const AscendC::LocalTensor &rowmaxUb, + const AscendC::LocalTensor &tvUbTensor, uint32_t numRowsRound, uint32_t numElems, + uint32_t numElemsAligned) + { + AscendC::BlockReduceMax(tvUbTensor, srcUb, numRowsRound * numElemsAligned / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + AscendC::BlockReduceMax(tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, + numRowsRound * numElemsAligned / FLOAT_BLOCK_SIZE / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + AscendC::BlockReduceMax(rowmaxUb, tvUbTensor[REDUCE_UB_SIZE], + numRowsRound * numElemsAligned / FLOAT_VECTOR_SIZE / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void RowmaxSPECTILE256(const AscendC::LocalTensor &srcUb, const AscendC::LocalTensor &rowmaxUb, + const AscendC::LocalTensor &tvUbTensor, uint32_t numRowsRound, uint32_t numElems, + uint32_t numElemsAligned) + { + AscendC::BlockReduceMax(tvUbTensor, srcUb, numRowsRound * numElemsAligned / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + SetVecMask(ROW_OPS_SPEC_MASK_32); + AscendC::BlockReduceMax(tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, numRowsRound, 0, 1, 1, 4); + AscendC::PipeBarrier(); + SetBlockReduceMask(ROW_OPS_SPEC_MASK_4); + AscendC::BlockReduceMax( + rowmaxUb, tvUbTensor[REDUCE_UB_SIZE], + (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, 1, 1, 8); + AscendC::PipeBarrier(); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + + CATLASS_DEVICE + void RowmaxTAILTILE(const AscendC::LocalTensor &srcUb, const AscendC::LocalTensor &rowmaxUb, + const AscendC::LocalTensor &tvUbTensor, uint32_t numRowsRound, uint32_t numElems, + uint32_t numElemsAligned) + { + if (numElems >= FLOAT_VECTOR_SIZE) { + AscendC::BlockReduceMax(tvUbTensor, srcUb, numRowsRound, 0, 1, 1, + numElemsAligned / FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + AscendC::BlockReduceMax( + rowmaxUb, tvUbTensor, (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + for (uint64_t rowmax_idx = 1; rowmax_idx < (uint64_t)numElems / FLOAT_VECTOR_SIZE; ++rowmax_idx) { + AscendC::BlockReduceMax(tvUbTensor, srcUb[rowmax_idx * FLOAT_VECTOR_SIZE], numRowsRound, + 
0, 1, 1, numElemsAligned / FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + AscendC::BlockReduceMax( + tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, + (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, 1, 1, 8); + AscendC::PipeBarrier(); + SetVecMask(numRowsRound); + AscendC::Max(rowmaxUb, rowmaxUb, tvUbTensor[REDUCE_UB_SIZE], (uint64_t)0, 1, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } + if (numElems % FLOAT_VECTOR_SIZE > 0) { + SetVecMask(numElems % FLOAT_VECTOR_SIZE); + AscendC::BlockReduceMax(tvUbTensor, srcUb[numElems / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + numRowsRound, 0, 1, 1, numElemsAligned / FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + SetBlockReduceMask((numElems % FLOAT_VECTOR_SIZE + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE); + if (numElems < FLOAT_VECTOR_SIZE) { + AscendC::BlockReduceMax( + rowmaxUb, tvUbTensor, (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + 0, 1, 1, 8); + AscendC::PipeBarrier(); + } else { + AscendC::BlockReduceMax( + tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, + (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, 1, 1, 8); + AscendC::PipeBarrier(); + SetVecMask(numRowsRound); + AscendC::Max(rowmaxUb, rowmaxUb, tvUbTensor[REDUCE_UB_SIZE], (uint64_t)0, 1, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + } + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } + + CATLASS_DEVICE + void SubCoreCompute(AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gInput, + uint32_t m, uint32_t nReal, uint32_t nStride, uint32_t pingpongFlag, uint32_t rowOffset, + uint32_t sUbOffset, uint32_t nIdx, uint32_t &glFlag) + { + uint32_t round_m = (m + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE * FLOAT_BLOCK_SIZE; + AscendC::WaitFlag(pingpongFlag); + // input QK + AscendC::DataCopy(lsUbTensor[sUbOffset], gInput, AscendC::DataCopyParams(m, nStride / FLOAT_BLOCK_SIZE, 0, 0)); + + AscendC::SetFlag(pingpongFlag); + AscendC::WaitFlag(pingpongFlag); + + // *** ls = tor * ls + AscendC::Muls(lsUbTensor[sUbOffset], lsUbTensor[sUbOffset], tor, (uint64_t)0, + (m * nStride + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 8, 8)); + + AscendC::PipeBarrier(); + + if (kvSplitCoreNum != 1) { + if (nIdx == 0) { + if (glFlag == 1) { + AscendC::WaitFlag(EVENT_ID2); + glFlag = 0; + } + } + } + + if (nReal == 512) { + RowmaxSPECTILE512(lsUbTensor[sUbOffset], lmUbTensor[rowOffset], tvUbTensor, round_m, nReal, nStride); + } else if (nReal == 256) { + RowmaxSPECTILE256(lsUbTensor[sUbOffset], lmUbTensor[rowOffset], tvUbTensor, round_m, nReal, nStride); + } else { + RowmaxTAILTILE(lsUbTensor[sUbOffset], lmUbTensor[rowOffset], tvUbTensor, round_m, nReal, nStride); + } + + if (nIdx == 0) { + AscendC::DataCopy(hmUbTensor[rowOffset], lmUbTensor[rowOffset], + AscendC::DataCopyParams(1, round_m / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::PipeBarrier(); + } else { + SetVecMask(m); + // *** hm = vmax(lm, gm) + AscendC::Max(hmUbTensor[rowOffset], lmUbTensor[rowOffset], gmUbTensor[rowOffset], (uint64_t)0, + 1, AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + + AscendC::PipeBarrier(); + // *** dm = gm - hm + AscendC::Sub(dmUbTensor[((nIdx / S_BLOCK_STACK) % 2) * UB_FLOAT_LINE_SIZE + rowOffset], + gmUbTensor[rowOffset], hmUbTensor[rowOffset], (uint64_t)0, 1, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + + AscendC::PipeBarrier(); + // *** dm = exp(dm) + 
AscendC::Exp(dmUbTensor[((nIdx / S_BLOCK_STACK) % 2) * UB_FLOAT_LINE_SIZE + rowOffset], + dmUbTensor[((nIdx / S_BLOCK_STACK) % 2) * UB_FLOAT_LINE_SIZE + rowOffset], + (uint64_t)0, 1, AscendC::UnaryRepeatParams(1, 1, 8, 8)); + } + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + AscendC::PipeBarrier(); + // *** gm = hm + AscendC::DataCopy(gmUbTensor[rowOffset], hmUbTensor[rowOffset], + AscendC::DataCopyParams(1, round_m / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::PipeBarrier(); + // *** hm_block = expand_to_block(hm), 存放于 tv + AscendC::Brcb(tvUbTensor.template ReinterpretCast(), + hmUbTensor[rowOffset].template ReinterpretCast(), round_m / FLOAT_BLOCK_SIZE, + AscendC::BrcbRepeatParams(1, 8)); + AscendC::PipeBarrier(); + // *** ls = ls - hm_block + for (uint32_t subIdx = 0; subIdx < nReal / FLOAT_VECTOR_SIZE; ++subIdx) { + AscendC::Sub( + lsUbTensor[sUbOffset][subIdx * FLOAT_VECTOR_SIZE], lsUbTensor[sUbOffset][subIdx * FLOAT_VECTOR_SIZE], + tvUbTensor, (uint64_t)0, m, + AscendC::BinaryRepeatParams(1, 1, 0, nStride / FLOAT_BLOCK_SIZE, nStride / FLOAT_BLOCK_SIZE, 1)); + } + if (nReal % FLOAT_VECTOR_SIZE > 0) { + SetVecMask(nReal % FLOAT_VECTOR_SIZE); + AscendC::Sub( + lsUbTensor[sUbOffset][nReal / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + lsUbTensor[sUbOffset][nReal / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], tvUbTensor, (uint64_t)0, m, + AscendC::BinaryRepeatParams(1, 1, 0, nStride / FLOAT_BLOCK_SIZE, nStride / FLOAT_BLOCK_SIZE, 1)); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + AscendC::PipeBarrier(); + + // *** ls = exp(ls) + AscendC::Exp(lsUbTensor[sUbOffset], lsUbTensor[sUbOffset], (uint64_t)0, + (m * nStride + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 8, 8)); + AscendC::PipeBarrier(); + // *** ll = rowsum(ls32) + if (nReal == 512) { + RowsumSPECTILE512(lsUbTensor[sUbOffset], llUbTensor[rowOffset], tvUbTensor, round_m, nReal, nStride); + } else if (nReal == 256) { + RowsumSPECTILE256(lsUbTensor[sUbOffset], llUbTensor[rowOffset], tvUbTensor, round_m, nReal, nStride); + } else { + RowsumTAILTILE(lsUbTensor[sUbOffset], llUbTensor[rowOffset], tvUbTensor, round_m, nReal, nStride); + } + + // *** lp = castfp32to16(ls) + if (std::is_same::value) { + AscendC::Cast( + lpUbTensor[sUbOffset * 2], lsUbTensor[sUbOffset], AscendC::RoundMode::CAST_RINT, (uint64_t)0, + (m * nStride + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, AscendC::UnaryRepeatParams(1, 1, 4, 8)); + } else { + AscendC::Cast( + lpUbTensor[sUbOffset * 2], lsUbTensor[sUbOffset], AscendC::RoundMode::CAST_NONE, (uint64_t)0, + (m * nStride + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, AscendC::UnaryRepeatParams(1, 1, 4, 8)); + } + + AscendC::SetFlag(pingpongFlag); + AscendC::WaitFlag(pingpongFlag); + AscendC::DataCopy(gOutput, lpUbTensor[sUbOffset * 2], AscendC::DataCopyParams(m, nStride * 2 / 32, 0, 0)); + AscendC::SetFlag(pingpongFlag); + if (nIdx == 0) { + // *** gl = ll + AscendC::DataCopy(glUbTensor[rowOffset], llUbTensor[rowOffset], + AscendC::DataCopyParams(1, round_m / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::PipeBarrier(); + } else { + SetVecMask(m); + // *** gl = dm * gl + AscendC::Mul( + glUbTensor[rowOffset], dmUbTensor[((nIdx / S_BLOCK_STACK) % 2) * UB_FLOAT_LINE_SIZE + rowOffset], + glUbTensor[rowOffset], (uint64_t)0, 1, AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + // *** gl = ll + gl + AscendC::Add(glUbTensor[rowOffset], glUbTensor[rowOffset], llUbTensor[rowOffset], (uint64_t)0, + 1, AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + 
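The sequence above (scale by tor, row max, hm = max(lm, gm), dm = exp(gm - hm), subtract, exp, row sum, gl = dm * gl + ll) is one tile of the online softmax. A host-side reference of the per-row arithmetic, under the assumption that nIdx == 0 corresponds to the first tile of a row; the names are hypothetical and this is not the kernel implementation.

// Host-side reference for one online-softmax tile (editorial sketch).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct RowState {
    float gm;   // running row max (set on the first tile)
    float gl;   // running row sum of exp(score - gm)
};

// Process one KV tile for a single row of tor-scaled scores and return
// p = exp(score - newMax), updating the running statistics like hm/dm/gl above.
std::vector<float> SoftmaxTile(const std::vector<float> &scores, RowState &st, bool firstTile)
{
    const float lm = *std::max_element(scores.begin(), scores.end());  // local row max
    const float hm = firstTile ? lm : std::max(lm, st.gm);             // hm = max(lm, gm)
    const float dm = firstTile ? 1.0f : std::exp(st.gm - hm);          // dm = exp(gm - hm)

    std::vector<float> p(scores.size());
    float ll = 0.0f;                                                   // local row sum
    for (size_t i = 0; i < scores.size(); ++i) {
        p[i] = std::exp(scores[i] - hm);                               // ls = exp(ls - hm_block)
        ll += p[i];
    }
    st.gm = hm;                                                        // gm = hm
    st.gl = firstTile ? ll : dm * st.gl + ll;                          // gl = dm * gl + ll
    return p;                                                          // cast to half/bf16 in the kernel
}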
AscendC::PipeBarrier(); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gInput, + const LayoutOutput &layoutOutput, const LayoutInput &layoutInput, GemmCoord actualBlockShape, + uint32_t nIdx, uint32_t &glFlag) + { + uint32_t cur_head_num = actualBlockShape.m(); + uint32_t qkN = actualBlockShape.n(); + uint32_t qkRoundN = layoutInput.stride(0); + uint32_t pingpongFlag = 0; + + uint32_t subBlockIdx = AscendC::GetSubBlockIdx(); + uint32_t subBlockNum = AscendC::GetSubBlockNum(); + + uint32_t subM = (subBlockIdx == 1) ? (cur_head_num - cur_head_num / 2) : cur_head_num / 2; + + uint32_t mEnd = (subM + M_SLICE - 1) / M_SLICE; + + for (uint32_t mInd = 0; mInd < mEnd; mInd++) { + uint32_t rowOffset = mInd * M_SLICE; + uint32_t currM = mInd == mEnd - 1 ? subM - rowOffset : M_SLICE; + uint32_t sUbOffset = pingpongFlag * 8192; + int64_t offsetOutput = rowOffset * qkRoundN; + auto gOutputThisSubBlock = gOutput[offsetOutput]; + int64_t offsetInput = rowOffset * qkRoundN; + auto gInputThisSubBlock = gInput[offsetInput]; + if (mInd == 0) { + Arch::CrossCoreWaitFlag(qkReady); + } + if (currM == 0) { + continue; + } + SubCoreCompute(gOutputThisSubBlock, gInputThisSubBlock, currM, qkN, qkRoundN, pingpongFlag, rowOffset, + sUbOffset, nIdx, glFlag); + pingpongFlag = 1 - pingpongFlag; + } + } + +private: + float tor; + uint32_t pingpongFlag = 0; + uint32_t kvSplitCoreNum = 1; + + AscendC::LocalTensor lsUbTensor; + AscendC::LocalTensor lpUbTensor; + AscendC::LocalTensor lmUbTensor; + AscendC::LocalTensor hmUbTensor; + AscendC::LocalTensor gmUbTensor; + AscendC::LocalTensor dmUbTensor; + AscendC::LocalTensor llUbTensor; + AscendC::LocalTensor tvUbTensor; + AscendC::LocalTensor glUbTensor; + + Arch::CrossCoreFlag qkReady{QK_READY_ID}; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_MLA_TP1_SOFTMAX_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_online_softmax_no_mask.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_online_softmax_no_mask.hpp new file mode 100644 index 00000000..d18062a4 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_online_softmax_no_mask.hpp @@ -0,0 +1,768 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_ONLINE_SOFTMAX_NO_MASK_HPP +#define CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_ONLINE_SOFTMAX_NO_MASK_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue +{ +public: + using DispatchPolicy = EpilogueAtlasA2OnlineSoftmax; + using ArchTag = typename DispatchPolicy::ArchTag; + using ElementOutput = typename OutputType_::Element; + using ElementInput = typename InputType_::Element; + using ElementMask = typename MaskType_::Element; + + using LayoutOutput = typename OutputType_::Layout; + using LayoutInput = typename InputType_::Layout; + using LayoutMask = typename MaskType_::Layout; + + static constexpr uint32_t FLOAT_BLOCK_SIZE = 8; + static constexpr uint32_t FLOAT_VECTOR_SIZE = 64; + static constexpr uint32_t HALF_VECTOR_SIZE = 128; + static constexpr uint32_t BLOCK_SIZE = 16; + static constexpr uint32_t UB_UINT8_VECTOR_SIZE = 1024; + static constexpr uint32_t UB_UINT8_BLOCK_SIZE = 16384; + static constexpr uint32_t VECTOR_SIZE = 128; + static constexpr uint32_t MAX_UB_S_ELEM_NUM = 8192; + + static constexpr uint32_t REDUCE_UB_SIZE = 1024; + static constexpr uint32_t ROW_OPS_SPEC_MASK_32 = 32; + static constexpr uint32_t ROW_OPS_SPEC_MASK_4 = 4; + static constexpr uint32_t MAX_ROW_NUM_SUB_CORE = 128; + static constexpr int64_t UB_FLOAT_LINE_SIZE = 64; + enum class MaskCategory { NO_MASK = 0, CAUSAL_MASK = 1 }; + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, float scaleValue_) + { + // Allocate UB space + constexpr uint32_t LS_UB_TENSOR_OFFSET = 0; + constexpr uint32_t LP_UB_TENSOR_OFFSET = 4 * UB_UINT8_BLOCK_SIZE; + constexpr uint32_t MASK32_UB_TENSOR_OFFSET = 4 * UB_UINT8_BLOCK_SIZE; + + constexpr uint32_t TV_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE; + constexpr uint32_t LM_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE + 8 * UB_UINT8_VECTOR_SIZE; + + constexpr uint32_t HM_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE + 9 * UB_UINT8_VECTOR_SIZE; + constexpr uint32_t GM_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE + 10 * UB_UINT8_VECTOR_SIZE; + constexpr uint32_t LL_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE + 11 * UB_UINT8_VECTOR_SIZE; + constexpr uint32_t GL_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE + 12 * UB_UINT8_VECTOR_SIZE; + constexpr uint32_t DM_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE + 13 * UB_UINT8_VECTOR_SIZE; + + constexpr uint32_t MASK_UB_TENSOR_OFFSET = 11 * UB_UINT8_BLOCK_SIZE; + + scaleValue = scaleValue_; + lsUbTensor = resource.ubBuf.template GetBufferByByte(LS_UB_TENSOR_OFFSET); + lpUbTensor = resource.ubBuf.template GetBufferByByte(LP_UB_TENSOR_OFFSET); + maskUbTensor = resource.ubBuf.template GetBufferByByte(MASK_UB_TENSOR_OFFSET); + maskUbTensor32 = resource.ubBuf.template GetBufferByByte(MASK32_UB_TENSOR_OFFSET); + lmUbTensor = resource.ubBuf.template GetBufferByByte(LM_UB_TENSOR_OFFSET); + hmUbTensor = resource.ubBuf.template GetBufferByByte(HM_UB_TENSOR_OFFSET); + gmUbTensor = resource.ubBuf.template GetBufferByByte(GM_UB_TENSOR_OFFSET); + dmUbTensor = resource.ubBuf.template GetBufferByByte(DM_UB_TENSOR_OFFSET); + llUbTensor = resource.ubBuf.template GetBufferByByte(LL_UB_TENSOR_OFFSET); + tvUbTensor = resource.ubBuf.template GetBufferByByte(TV_UB_TENSOR_OFFSET); + glUbTensor = 
resource.ubBuf.template GetBufferByByte(GL_UB_TENSOR_OFFSET); + } + + CATLASS_DEVICE + ~BlockEpilogue() {} + + template + CATLASS_DEVICE T Min(T a, T b) + { + return (a > b) ? b : a; + } + + CATLASS_DEVICE + void SetVecMask(int32_t len) + { + uint64_t mask = 0; + uint64_t one = 1; + uint64_t temp = len % FLOAT_VECTOR_SIZE; + for (int64_t i = 0; i < temp; i++) { + mask |= one << i; + } + + if (len == VECTOR_SIZE || len == 0) { + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else if (len >= FLOAT_VECTOR_SIZE) { + AscendC::SetVectorMask(mask, (uint64_t)-1); + } else { + AscendC::SetVectorMask(0x0, mask); + } + } + + CATLASS_DEVICE + void SetBlockReduceMask(int32_t len) + { + if (len > 8 || len < 1) { + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + return; + } + uint64_t subMask = ((uint64_t)1 << len) - 1; + uint64_t maskValue = (subMask << 48) + (subMask << 32) + (subMask << 16) + subMask + (subMask << 56) + + (subMask << 40) + (subMask << 24) + (subMask << 8); + AscendC::SetVectorMask(maskValue, maskValue); + } + + CATLASS_DEVICE + void RowsumSPECTILE512(const AscendC::LocalTensor &srcUb, const AscendC::LocalTensor &rowsumUb, + const AscendC::LocalTensor &tvUbTensor, uint32_t numRowsRound, uint32_t numElems, + uint32_t numElemsAligned) + { + AscendC::BlockReduceSum(tvUbTensor, srcUb, numRowsRound * numElemsAligned / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + + AscendC::BlockReduceSum(tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, + numRowsRound * numElemsAligned / FLOAT_BLOCK_SIZE / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + AscendC::BlockReduceSum(rowsumUb, tvUbTensor[REDUCE_UB_SIZE], + numRowsRound * numElemsAligned / FLOAT_VECTOR_SIZE / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void RowsumSPECTILE256(const AscendC::LocalTensor &srcUb, const AscendC::LocalTensor &rowsumUb, + const AscendC::LocalTensor &tvUbTensor, uint32_t numRowsRound, uint32_t numElems, + uint32_t numElemsAligned) + { + AscendC::BlockReduceSum(tvUbTensor, srcUb, numRowsRound * numElemsAligned / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + SetVecMask(ROW_OPS_SPEC_MASK_32); + AscendC::BlockReduceSum(tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, numRowsRound, 0, 1, 1, 4); + AscendC::PipeBarrier(); + SetBlockReduceMask(ROW_OPS_SPEC_MASK_4); + AscendC::BlockReduceSum( + rowsumUb, tvUbTensor[REDUCE_UB_SIZE], + (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, 1, 1, 8); + AscendC::PipeBarrier(); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + + CATLASS_DEVICE + void RowsumTAILTILE(const AscendC::LocalTensor &srcUb, const AscendC::LocalTensor &rowsumUb, + const AscendC::LocalTensor &tvUbTensor, uint32_t numRowsRound, uint32_t numElems, + uint32_t numElemsAligned) + { + if (numElems >= FLOAT_VECTOR_SIZE) { + AscendC::BlockReduceSum(tvUbTensor, srcUb, numRowsRound, 0, 1, 1, + numElemsAligned / FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + AscendC::BlockReduceSum( + rowsumUb, tvUbTensor, (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + for (uint64_t rowSumIdx = 1; rowSumIdx < (uint64_t)numElems / FLOAT_VECTOR_SIZE; ++rowSumIdx) { + AscendC::BlockReduceSum(tvUbTensor, srcUb[rowSumIdx * FLOAT_VECTOR_SIZE], numRowsRound, 0, + 1, 1, numElemsAligned / FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + AscendC::BlockReduceSum( + tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, + (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 
1) / FLOAT_VECTOR_SIZE, 0, 1, 1, 8); + AscendC::PipeBarrier(); + SetVecMask(numRowsRound); + AscendC::Add(rowsumUb, rowsumUb, tvUbTensor[REDUCE_UB_SIZE], (uint64_t)0, 1, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } + if (numElems % FLOAT_VECTOR_SIZE > 0) { + SetVecMask(numElems % FLOAT_VECTOR_SIZE); + AscendC::BlockReduceSum(tvUbTensor, srcUb[numElems / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + numRowsRound, 0, 1, 1, numElemsAligned / FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + SetBlockReduceMask((numElems % FLOAT_VECTOR_SIZE + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE); + if (numElems < FLOAT_VECTOR_SIZE) { + AscendC::BlockReduceSum( + rowsumUb, tvUbTensor, (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + 0, 1, 1, 8); + AscendC::PipeBarrier(); + } else { + AscendC::BlockReduceSum( + tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, + (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, 1, 1, 8); + AscendC::PipeBarrier(); + SetVecMask(numRowsRound); + AscendC::Add(rowsumUb, rowsumUb, tvUbTensor[REDUCE_UB_SIZE], (uint64_t)0, 1, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + } + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } + + CATLASS_DEVICE + void RowmaxSPECTILE512(const AscendC::LocalTensor &srcUb, const AscendC::LocalTensor &rowmaxUb, + const AscendC::LocalTensor &tvUbTensor, uint32_t numRowsRound, uint32_t numElems, + uint32_t numElemsAligned) + { + AscendC::BlockReduceMax(tvUbTensor, srcUb, numRowsRound * numElemsAligned / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + AscendC::BlockReduceMax(tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, + numRowsRound * numElemsAligned / FLOAT_BLOCK_SIZE / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + AscendC::BlockReduceMax(rowmaxUb, tvUbTensor[REDUCE_UB_SIZE], + numRowsRound * numElemsAligned / FLOAT_VECTOR_SIZE / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void RowmaxSPECTILE256(const AscendC::LocalTensor &srcUb, const AscendC::LocalTensor &rowmaxUb, + const AscendC::LocalTensor &tvUbTensor, uint32_t numRowsRound, uint32_t numElems, + uint32_t numElemsAligned) + { + AscendC::BlockReduceMax(tvUbTensor, srcUb, numRowsRound * numElemsAligned / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + SetVecMask(ROW_OPS_SPEC_MASK_32); + AscendC::BlockReduceMax(tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, numRowsRound, 0, 1, 1, 4); + AscendC::PipeBarrier(); + SetBlockReduceMask(ROW_OPS_SPEC_MASK_4); + AscendC::BlockReduceMax( + rowmaxUb, tvUbTensor[REDUCE_UB_SIZE], + (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, 1, 1, 8); + AscendC::PipeBarrier(); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + + CATLASS_DEVICE + void RowmaxTAILTILE(const AscendC::LocalTensor &srcUb, const AscendC::LocalTensor &rowmaxUb, + const AscendC::LocalTensor &tvUbTensor, uint32_t numRowsRound, uint32_t numElems, + uint32_t numElemsAligned) + { + if (numElems >= FLOAT_VECTOR_SIZE) { + AscendC::BlockReduceMax(tvUbTensor, srcUb, numRowsRound, 0, 1, 1, + numElemsAligned / FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + AscendC::BlockReduceMax( + rowmaxUb, tvUbTensor, (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, + 1, 1, 8); + AscendC::PipeBarrier(); + for (uint64_t rowmax_idx = 1; rowmax_idx < (uint64_t)numElems / FLOAT_VECTOR_SIZE; ++rowmax_idx) { + 
AscendC::BlockReduceMax(tvUbTensor, srcUb[rowmax_idx * FLOAT_VECTOR_SIZE], numRowsRound, + 0, 1, 1, numElemsAligned / FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + AscendC::BlockReduceMax( + tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, + (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, 1, 1, 8); + AscendC::PipeBarrier(); + SetVecMask(numRowsRound); + AscendC::Max(rowmaxUb, rowmaxUb, tvUbTensor[REDUCE_UB_SIZE], (uint64_t)0, 1, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } + if (numElems % FLOAT_VECTOR_SIZE > 0) { + SetVecMask(numElems % FLOAT_VECTOR_SIZE); + AscendC::BlockReduceMax(tvUbTensor, srcUb[numElems / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + numRowsRound, 0, 1, 1, numElemsAligned / FLOAT_BLOCK_SIZE); + AscendC::PipeBarrier(); + SetBlockReduceMask((numElems % FLOAT_VECTOR_SIZE + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE); + if (numElems < FLOAT_VECTOR_SIZE) { + AscendC::BlockReduceMax( + rowmaxUb, tvUbTensor, (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + 0, 1, 1, 8); + AscendC::PipeBarrier(); + } else { + AscendC::BlockReduceMax( + tvUbTensor[REDUCE_UB_SIZE], tvUbTensor, + (numRowsRound * FLOAT_BLOCK_SIZE + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, 0, 1, 1, 8); + AscendC::PipeBarrier(); + SetVecMask(numRowsRound); + AscendC::Max(rowmaxUb, rowmaxUb, tvUbTensor[REDUCE_UB_SIZE], (uint64_t)0, 1, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + } + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } + + CATLASS_DEVICE + void CopySGmToUb(AscendC::GlobalTensor gInput, uint32_t sUbOffset, uint32_t rowNumCurLoop, + uint32_t columnNumRound, uint32_t columnNumPad) + { + // input S + AscendC::DataCopy(lsUbTensor[sUbOffset], gInput, + AscendC::DataCopyParams(rowNumCurLoop, columnNumRound / FLOAT_BLOCK_SIZE, + (columnNumPad - columnNumRound) / FLOAT_BLOCK_SIZE, 0)); + } + + CATLASS_DEVICE + void CopyMaskGmToUb(AscendC::GlobalTensor gMask, uint32_t columnNum, uint32_t columnNumRound, + uint32_t maskStride, uint32_t qSBlockSize, uint32_t proTokenIdx, uint32_t proTokenNum, + uint32_t integralHeadNum, uint32_t epiTokenNum) + { + uint32_t innerUbRowOffset = 0; + if (proTokenNum != 0) { + AscendC::DataCopyPad( + maskUbTensor[innerUbRowOffset], gMask[proTokenIdx * maskStride], + AscendC::DataCopyExtParams(proTokenNum, columnNum * 2, (maskStride - columnNum) * 2, 0, 0), + AscendC::DataCopyPadExtParams(false, 0, 0, 0)); + innerUbRowOffset += proTokenNum * columnNumRound; + } + for (uint32_t headIdx = 0; headIdx < integralHeadNum; headIdx++) { + AscendC::DataCopyPad( + maskUbTensor[innerUbRowOffset], gMask, + AscendC::DataCopyExtParams(qSBlockSize, columnNum * 2, (maskStride - columnNum) * 2, 0, 0), + AscendC::DataCopyPadExtParams(false, 0, 0, 0)); + innerUbRowOffset += qSBlockSize * columnNumRound; + } + if (epiTokenNum != 0) { + AscendC::DataCopyPad( + maskUbTensor[innerUbRowOffset], gMask, + AscendC::DataCopyExtParams(epiTokenNum, columnNum * 2, (maskStride - columnNum) * 2, 0, 0), + AscendC::DataCopyPadExtParams(false, 0, 0, 0)); + } + } + + CATLASS_DEVICE + void ScaleS(uint32_t sUbOffset, uint32_t rowNumCurLoop, uint32_t columnNumRound) + { + // *** ls = scaleValue * ls + AscendC::Muls(lsUbTensor[sUbOffset], lsUbTensor[sUbOffset], scaleValue, (uint64_t)0, + (rowNumCurLoop * columnNumRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 8, 8)); + + AscendC::PipeBarrier(); + } + 
+ CATLASS_DEVICE + void UpCastMask(uint32_t rowNumCurLoop, uint32_t columnNumRound) + { + AscendC::Cast( + maskUbTensor32, maskUbTensor, AscendC::RoundMode::CAST_NONE, (uint64_t)0, + (rowNumCurLoop * columnNumRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 8, 4)); + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void ApplyMask(uint32_t sUbOffset, uint32_t rowNumCurLoop, uint32_t columnNumRound) + { + AscendC::Muls(maskUbTensor32, maskUbTensor32, (float)-3e38, (uint64_t)0, + (rowNumCurLoop * columnNumRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 8, 8)); + AscendC::PipeBarrier(); + AscendC::Add(lsUbTensor[sUbOffset], lsUbTensor[sUbOffset], maskUbTensor32, (uint64_t)0, + (rowNumCurLoop * columnNumRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void CalcLocalRowMax(uint32_t sUbOffset, uint32_t rowNumCurLoopRound, uint32_t columnNum, uint32_t columnNumRound, + uint32_t rowOffset) + { + if (columnNum == 512) { + RowmaxSPECTILE512(lsUbTensor[sUbOffset], lmUbTensor[rowOffset], tvUbTensor, rowNumCurLoopRound, columnNum, + columnNumRound); + } else if (columnNum == 256) { + RowmaxSPECTILE256(lsUbTensor[sUbOffset], lmUbTensor[rowOffset], tvUbTensor, rowNumCurLoopRound, columnNum, + columnNumRound); + } else { + RowmaxTAILTILE(lsUbTensor[sUbOffset], lmUbTensor[rowOffset], tvUbTensor, rowNumCurLoopRound, columnNum, + columnNumRound); + } + } + + CATLASS_DEVICE + void UpdateGlobalRowMax(uint32_t rowNumCurLoop, uint32_t rowNumCurLoopRound, uint32_t columnNum, + uint32_t columnNumRound, uint32_t dmUbOffsetCurCycle, uint32_t rowOffset, + uint32_t isFirstStackTile) + { + if (isFirstStackTile) { + AscendC::DataCopy(hmUbTensor[rowOffset], lmUbTensor[rowOffset], + AscendC::DataCopyParams(1, rowNumCurLoopRound / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::PipeBarrier(); + } else { + SetVecMask(rowNumCurLoop); + // *** hm = vmax(lm, gm) + AscendC::Max(hmUbTensor[rowOffset], lmUbTensor[rowOffset], gmUbTensor[rowOffset], (uint64_t)0, + 1, AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + + AscendC::PipeBarrier(); + // *** dm = gm - hm + AscendC::Sub(dmUbTensor[dmUbOffsetCurCycle], gmUbTensor[rowOffset], hmUbTensor[rowOffset], + (uint64_t)0, 1, AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + + AscendC::PipeBarrier(); + // *** dm = exp(dm) + AscendC::Exp(dmUbTensor[dmUbOffsetCurCycle], dmUbTensor[dmUbOffsetCurCycle], (uint64_t)0, 1, + AscendC::UnaryRepeatParams(1, 1, 8, 8)); + } + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + AscendC::PipeBarrier(); + // *** gm = hm + AscendC::DataCopy(gmUbTensor[rowOffset], hmUbTensor[rowOffset], + AscendC::DataCopyParams(1, rowNumCurLoopRound / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void CalcExp(uint32_t sUbOffset, uint32_t rowNumCurLoop, uint32_t rowNumCurLoopRound, uint32_t columnNum, + uint32_t columnNumRound, uint32_t rowOffset) + { + // *** hm_block = expand_to_block(hm), 存放于 tv + AscendC::Brcb(tvUbTensor.template ReinterpretCast(), + hmUbTensor[rowOffset].template ReinterpretCast(), rowNumCurLoopRound / FLOAT_BLOCK_SIZE, + AscendC::BrcbRepeatParams(1, 8)); + AscendC::PipeBarrier(); + // *** ls = ls - hm_block + for (uint32_t subIdx = 0; subIdx < columnNum / FLOAT_VECTOR_SIZE; ++subIdx) { + AscendC::Sub(lsUbTensor[sUbOffset][subIdx * FLOAT_VECTOR_SIZE], + lsUbTensor[sUbOffset][subIdx * FLOAT_VECTOR_SIZE], tvUbTensor, (uint64_t)0, + rowNumCurLoop, + 
AscendC::BinaryRepeatParams(1, 1, 0, columnNumRound / FLOAT_BLOCK_SIZE, + columnNumRound / FLOAT_BLOCK_SIZE, 1)); + } + if (columnNum % FLOAT_VECTOR_SIZE > 0) { + SetVecMask(columnNum % FLOAT_VECTOR_SIZE); + AscendC::Sub(lsUbTensor[sUbOffset][columnNum / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + lsUbTensor[sUbOffset][columnNum / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + tvUbTensor, (uint64_t)0, rowNumCurLoop, + AscendC::BinaryRepeatParams(1, 1, 0, columnNumRound / FLOAT_BLOCK_SIZE, + columnNumRound / FLOAT_BLOCK_SIZE, 1)); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + AscendC::PipeBarrier(); + // *** ls = exp(ls) + AscendC::Exp(lsUbTensor[sUbOffset], lsUbTensor[sUbOffset], (uint64_t)0, + (rowNumCurLoop * columnNumRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 8, 8)); + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void CalcLocalRowSum(uint32_t sUbOffset, uint32_t rowNumCurLoopRound, uint32_t columnNum, uint32_t columnNumRound, + uint32_t rowOffset) + { + // *** ll = rowsum(ls32) + if (columnNum == 512) { + RowsumSPECTILE512(lsUbTensor[sUbOffset], llUbTensor[rowOffset], tvUbTensor, rowNumCurLoopRound, columnNum, + columnNumRound); + } else if (columnNum == 256) { + RowsumSPECTILE256(lsUbTensor[sUbOffset], llUbTensor[rowOffset], tvUbTensor, rowNumCurLoopRound, columnNum, + columnNumRound); + } else { + RowsumTAILTILE(lsUbTensor[sUbOffset], llUbTensor[rowOffset], tvUbTensor, rowNumCurLoopRound, columnNum, + columnNumRound); + } + } + + CATLASS_DEVICE + void UpdateGlobalRowSum(uint32_t sUbOffset, uint32_t rowNumCurLoop, uint32_t rowNumCurLoopRound, + uint32_t dmUbOffsetCurCycle, uint32_t rowOffset, uint32_t isFirstStackTile) + { + if (isFirstStackTile) { + // *** gl = ll + AscendC::DataCopy(glUbTensor[rowOffset], llUbTensor[rowOffset], + AscendC::DataCopyParams(1, rowNumCurLoopRound / FLOAT_BLOCK_SIZE, 0, 0)); + AscendC::PipeBarrier(); + } else { + SetVecMask(rowNumCurLoop); + // *** gl = dm * gl + AscendC::Mul(glUbTensor[rowOffset], dmUbTensor[dmUbOffsetCurCycle], glUbTensor[rowOffset], + (uint64_t)0, 1, AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + // *** gl = ll + gl + AscendC::Add(glUbTensor[rowOffset], glUbTensor[rowOffset], llUbTensor[rowOffset], (uint64_t)0, + 1, AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8)); + AscendC::PipeBarrier(); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + } + + CATLASS_DEVICE + void DownCastP(uint32_t sUbOffset, uint32_t rowNumCurLoop, uint32_t columnNumRound) + { + // *** lp = castfp32to16(ls) + if (std::is_same::value) { + AscendC::Cast( + lpUbTensor[sUbOffset], lsUbTensor[sUbOffset], AscendC::RoundMode::CAST_RINT, (uint64_t)0, + (rowNumCurLoop * columnNumRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 4, 8)); + } else { + AscendC::Cast( + lpUbTensor[sUbOffset], lsUbTensor[sUbOffset], AscendC::RoundMode::CAST_NONE, (uint64_t)0, + (rowNumCurLoop * columnNumRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 4, 8)); + } + } + + CATLASS_DEVICE + void CopyPUbToGm(AscendC::GlobalTensor gOutput, uint32_t sUbOffset, uint32_t rowNumCurLoop, + uint32_t columnNumRound, uint32_t columnNumPad) + { + if (columnNumRound == columnNumPad) { + AscendC::DataCopy(gOutput, lpUbTensor[sUbOffset], + AscendC::DataCopyParams(1, rowNumCurLoop * columnNumRound / BLOCK_SIZE, 0, 0)); + } else { + AscendC::DataCopy(gOutput, lpUbTensor[sUbOffset], + AscendC::DataCopyParams(rowNumCurLoop, columnNumRound / BLOCK_SIZE, 
0, + (columnNumPad - columnNumRound) / BLOCK_SIZE)); + } + } + + template + CATLASS_DEVICE void SubCoreCompute(AscendC::GlobalTensor gOutput, const LayoutOutput &layoutOutput, + uint32_t rowOffset, uint32_t isFirstStackTile, uint32_t columnNumRound, + uint32_t pingpongFlag, uint32_t curStackTileMod) + { + uint32_t rowNumCurLoop = layoutOutput.shape(0); + uint32_t rowNumCurLoopRound = RoundUp(rowNumCurLoop, FLOAT_BLOCK_SIZE); + uint32_t columnNum = layoutOutput.shape(1); + // Align colNum to 16, for both float&half compute© + uint32_t columnNumPad = layoutOutput.stride(0); + uint32_t sUbOffset = pingpongFlag * MAX_UB_S_ELEM_NUM; + uint32_t dmUbOffsetCurCycle = curStackTileMod * MAX_ROW_NUM_SUB_CORE + rowOffset; + + CalcLocalRowMax(sUbOffset, rowNumCurLoopRound, columnNum, columnNumRound, rowOffset); + UpdateGlobalRowMax(rowNumCurLoop, rowNumCurLoopRound, columnNum, columnNumRound, dmUbOffsetCurCycle, rowOffset, + isFirstStackTile); + + CalcExp(sUbOffset, rowNumCurLoop, rowNumCurLoopRound, columnNum, columnNumRound, rowOffset); + if constexpr (maskCat == MaskCategory::NO_MASK) { + AscendC::WaitFlag(pingpongFlag); + } + + DownCastP(sUbOffset, rowNumCurLoop, columnNumRound); + AscendC::SetFlag(pingpongFlag); + + CalcLocalRowSum(sUbOffset, rowNumCurLoopRound, columnNum, columnNumRound, rowOffset); + AscendC::SetFlag(pingpongFlag); + + AscendC::WaitFlag(pingpongFlag); + CopyPUbToGm(gOutput, sUbOffset, rowNumCurLoop, columnNumRound, columnNumPad); + if constexpr (maskCat == MaskCategory::NO_MASK) { + AscendC::SetFlag(pingpongFlag); + } else if constexpr (maskCat == MaskCategory::CAUSAL_MASK) { + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + } + UpdateGlobalRowSum(sUbOffset, rowNumCurLoop, rowNumCurLoopRound, dmUbOffsetCurCycle, rowOffset, + isFirstStackTile); + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gInput, + const LayoutOutput &layoutOutput, const LayoutInput &layoutInput, GemmCoord actualBlockShape, + uint32_t isFirstStackTile, uint32_t qSBlockSize, uint32_t qNBlockSize, uint32_t curStackTileMod) + { + uint32_t rowNum = actualBlockShape.m(); + uint32_t columnNum = actualBlockShape.n(); + uint32_t columnNumRound = RoundUp(columnNum, BLOCK_SIZE); + uint32_t columnNumPad = layoutInput.stride(0); + + uint32_t subBlockIdx = AscendC::GetSubBlockIdx(); + uint32_t subBlockNum = AscendC::GetSubBlockNum(); + + uint32_t qNSplitSubBlock = qNBlockSize / subBlockNum; + uint32_t qNThisSubBlock = (qNBlockSize == 1) ? 0 + : (subBlockIdx == 1) ? (qNBlockSize - qNSplitSubBlock) + : qNSplitSubBlock; + uint32_t rowSplitSubBlock = (qNBlockSize == 1) ? (qSBlockSize / 2) : (qSBlockSize * qNSplitSubBlock); + uint32_t rowActualThisSubBlock = (subBlockIdx == 1) ? (rowNum - rowSplitSubBlock) : rowSplitSubBlock; + uint32_t rowOffsetThisSubBlock = subBlockIdx * rowSplitSubBlock; + uint32_t maxRowNumPerLoop = MAX_UB_S_ELEM_NUM / columnNumRound; + uint32_t rowNumTile = RoundDown(maxRowNumPerLoop, FLOAT_BLOCK_SIZE); + uint32_t rowLoopNum = CeilDiv(rowActualThisSubBlock, rowNumTile); + uint32_t preLoad = 1; + + for (uint32_t rowLoopIdx = 0; rowLoopIdx < rowLoopNum + preLoad; rowLoopIdx++) { + if (rowLoopIdx < rowLoopNum) { + uint32_t pingpongFlag = rowLoopIdx % 2; + uint32_t rowOffsetCurLoop = rowLoopIdx * rowNumTile; + uint32_t rowOffsetIoGm = rowOffsetCurLoop + rowOffsetThisSubBlock; + uint32_t rowNumCurLoop = + (rowLoopIdx == rowLoopNum - 1) ? 
(rowActualThisSubBlock - rowOffsetCurLoop) : rowNumTile; + // loop 0 mask load before cross core sync + + int64_t offsetInput = layoutInput.GetOffset(MatrixCoord(rowOffsetIoGm, 0)); + auto gInputCurLoop = gInput[offsetInput]; + + AscendC::WaitFlag(pingpongFlag); + CopySGmToUb(gInputCurLoop, (pingpongFlag * MAX_UB_S_ELEM_NUM), rowNumCurLoop, columnNumRound, + columnNumPad); + AscendC::SetFlag(pingpongFlag); + } + if (rowLoopIdx >= preLoad) { + uint32_t delayedRowLoopIdx = rowLoopIdx - preLoad; + uint32_t pingpongFlag = delayedRowLoopIdx % 2; + uint32_t rowOffsetCurLoop = delayedRowLoopIdx * rowNumTile; + uint32_t rowOffsetIoGm = rowOffsetCurLoop + rowOffsetThisSubBlock; + uint32_t rowNumCurLoop = + (delayedRowLoopIdx == rowLoopNum - 1) ? (rowActualThisSubBlock - rowOffsetCurLoop) : rowNumTile; + + int64_t offsetOutput = layoutOutput.GetOffset(MatrixCoord(rowOffsetIoGm, 0)); + auto gOutputCurLoop = gOutput[offsetOutput]; + auto layoutOutputCurLoop = layoutOutput.GetTileLayout(MatrixCoord(rowNumCurLoop, columnNum)); + AscendC::WaitFlag(pingpongFlag); + ScaleS((pingpongFlag * MAX_UB_S_ELEM_NUM), rowNumCurLoop, columnNumRound); + SubCoreCompute(gOutputCurLoop, layoutOutputCurLoop, rowOffsetCurLoop, + isFirstStackTile, columnNumRound, pingpongFlag, curStackTileMod); + } + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gInput, + AscendC::GlobalTensor gMask, const LayoutOutput &layoutOutput, + const LayoutInput &layoutInput, const LayoutInput &layoutMask, GemmCoord actualBlockShape, + uint32_t isFirstStackTile, uint32_t qSBlockSize, uint32_t qNBlockSize, uint32_t curStackTileMod, + Arch::CrossCoreFlag qkReady) + { + uint32_t rowNum = actualBlockShape.m(); + uint32_t columnNum = actualBlockShape.n(); + uint32_t columnNumRound = RoundUp(columnNum, BLOCK_SIZE); + uint32_t columnNumPad = layoutInput.stride(0); + uint32_t maskStride = layoutMask.stride(0); + uint32_t subBlockIdx = AscendC::GetSubBlockIdx(); + uint32_t subBlockNum = AscendC::GetSubBlockNum(); + + uint32_t qNSplitSubBlock = qNBlockSize / subBlockNum; + uint32_t qNThisSubBlock = (qNBlockSize == 1) ? 0 + : (subBlockIdx == 1) ? (qNBlockSize - qNSplitSubBlock) + : qNSplitSubBlock; + uint32_t rowSplitSubBlock = (qNBlockSize == 1) ? (qSBlockSize / 2) : (qSBlockSize * qNSplitSubBlock); + uint32_t rowActualThisSubBlock = (subBlockIdx == 1) ? (rowNum - rowSplitSubBlock) : rowSplitSubBlock; + uint32_t rowOffsetThisSubBlock = subBlockIdx * rowSplitSubBlock; + + uint32_t tokenNumPerHeadThisSubBlock = Min(qSBlockSize, rowActualThisSubBlock); + + uint32_t maskOffsetThisSubBlock = (qNBlockSize == 1) ? rowOffsetThisSubBlock : 0; + int64_t offsetMask = layoutMask.GetOffset(MatrixCoord(maskOffsetThisSubBlock, 0)); + auto gMaskThisSubBlock = gMask[offsetMask]; + auto layoutMaskThisSubBlock = layoutMask; + + uint32_t maxRowNumPerLoop = MAX_UB_S_ELEM_NUM / columnNumRound; + uint32_t rowNumTile = RoundDown(maxRowNumPerLoop, FLOAT_BLOCK_SIZE); + uint32_t rowLoopNum = CeilDiv(rowActualThisSubBlock, rowNumTile); + uint32_t preLoad = 1; + + if (rowActualThisSubBlock == 0) { + Arch::CrossCoreWaitFlag(qkReady); + return; + } + + for (uint32_t rowLoopIdx = 0; rowLoopIdx < rowLoopNum + preLoad; rowLoopIdx++) { + if (rowLoopIdx < rowLoopNum) { + uint32_t pingpongFlag = rowLoopIdx % 2; + uint32_t rowOffsetCurLoop = rowLoopIdx * rowNumTile; + uint32_t rowOffsetIoGm = rowOffsetCurLoop + rowOffsetThisSubBlock; + uint32_t rowNumCurLoop = + (rowLoopIdx == rowLoopNum - 1) ? 
(rowActualThisSubBlock - rowOffsetCurLoop) : rowNumTile; + // loop 0 mask load before cross core sync + if (rowLoopIdx == 0) { + // the token idx of the start token of the prologue part + uint32_t proTokenIdx = rowOffsetCurLoop % tokenNumPerHeadThisSubBlock; + // the token num of the prologue part + uint32_t proTokenNum = + Min(rowNumCurLoop, (tokenNumPerHeadThisSubBlock - proTokenIdx)) % tokenNumPerHeadThisSubBlock; + // the token num of the epilogue part + uint32_t integralHeadNum = (rowNumCurLoop - proTokenNum) / tokenNumPerHeadThisSubBlock; + // the number of integral heads within a cycle + uint32_t epiTokenNum = rowNumCurLoop - proTokenNum - integralHeadNum * tokenNumPerHeadThisSubBlock; + AscendC::WaitFlag(EVENT_ID2); + CopyMaskGmToUb(gMaskThisSubBlock, columnNum, columnNumRound, maskStride, qSBlockSize, proTokenIdx, + proTokenNum, integralHeadNum, epiTokenNum); + AscendC::SetFlag(EVENT_ID2); + Arch::CrossCoreWaitFlag(qkReady); + } + int64_t offsetInput = layoutInput.GetOffset(MatrixCoord(rowOffsetIoGm, 0)); + auto gInputCurLoop = gInput[offsetInput]; + AscendC::WaitFlag(pingpongFlag); + CopySGmToUb(gInputCurLoop, (pingpongFlag * MAX_UB_S_ELEM_NUM), rowNumCurLoop, columnNumRound, + columnNumPad); + AscendC::SetFlag(pingpongFlag); + } + if (rowLoopIdx >= preLoad) { + uint32_t delayedRowLoopIdx = rowLoopIdx - preLoad; + uint32_t pingpongFlag = delayedRowLoopIdx % 2; + uint32_t rowOffsetCurLoop = delayedRowLoopIdx * rowNumTile; + uint32_t rowNumCurLoop = + (delayedRowLoopIdx == rowLoopNum - 1) ? (rowActualThisSubBlock - rowOffsetCurLoop) : rowNumTile; + + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID2); + UpCastMask(rowNumCurLoop, columnNumRound); + AscendC::WaitFlag(pingpongFlag); + ScaleS((pingpongFlag * MAX_UB_S_ELEM_NUM), rowNumCurLoop, columnNumRound); + ApplyMask((pingpongFlag * MAX_UB_S_ELEM_NUM), rowNumCurLoop, columnNumRound); + AscendC::SetFlag(EVENT_ID2); + // next loop mask load + if (rowLoopIdx < rowLoopNum) { + uint32_t rowOffsetCurLoop = rowLoopIdx * rowNumTile; + uint32_t rowNumCurLoop = + (rowLoopIdx == rowLoopNum - 1) ? 
(rowActualThisSubBlock - rowOffsetCurLoop) : rowNumTile; + // the token idx of the start token of the prologue part + uint32_t proTokenIdx = rowOffsetCurLoop % tokenNumPerHeadThisSubBlock; + // the token num of the prologue part + uint32_t proTokenNum = + Min(rowNumCurLoop, (tokenNumPerHeadThisSubBlock - proTokenIdx)) % tokenNumPerHeadThisSubBlock; + // the number of integral heads within a cycle + uint32_t integralHeadNum = (rowNumCurLoop - proTokenNum) / tokenNumPerHeadThisSubBlock; + // the token num of the epilogue part + uint32_t epiTokenNum = rowNumCurLoop - proTokenNum - integralHeadNum * tokenNumPerHeadThisSubBlock; + AscendC::WaitFlag(EVENT_ID2); + CopyMaskGmToUb(gMaskThisSubBlock, columnNum, columnNumRound, maskStride, qSBlockSize, proTokenIdx, + proTokenNum, integralHeadNum, epiTokenNum); + AscendC::SetFlag(EVENT_ID2); + } + // online softmax vectorized compute + uint32_t rowOffsetIoGm = rowOffsetCurLoop + rowOffsetThisSubBlock; + int64_t offsetOutput = layoutOutput.GetOffset(MatrixCoord(rowOffsetIoGm, 0)); + auto gOutputCurLoop = gOutput[offsetOutput]; + auto layoutOutputCurLoop = layoutOutput.GetTileLayout(MatrixCoord(rowNumCurLoop, columnNum)); + SubCoreCompute(gOutputCurLoop, layoutOutputCurLoop, rowOffsetCurLoop, + isFirstStackTile, columnNumRound, pingpongFlag, + curStackTileMod); + } + } + } + +private: + float scaleValue; + AscendC::LocalTensor lsUbTensor; + AscendC::LocalTensor lpUbTensor; + AscendC::LocalTensor maskUbTensor; + AscendC::LocalTensor maskUbTensor32; + AscendC::LocalTensor lmUbTensor; + AscendC::LocalTensor hmUbTensor; + AscendC::LocalTensor gmUbTensor; + AscendC::LocalTensor dmUbTensor; + AscendC::LocalTensor llUbTensor; + AscendC::LocalTensor tvUbTensor; + AscendC::LocalTensor glUbTensor; +}; +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_ONLINE_SOFTMAX_NO_MASK_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_per_token_dequant.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_per_token_dequant.hpp new file mode 100644 index 00000000..859ad59a --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_per_token_dequant.hpp @@ -0,0 +1,594 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP +#define CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/gemm_type.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue, CType_, ScaleType_, PerTokenScaleType_, DType_, + TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_, + EpilogueTileSwizzle_> +{ +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementScale = typename ScaleType_::Element; + using LayoutScale = typename ScaleType_::Layout; + using ElementPerTokenScale = typename PerTokenScaleType_::Element; + using LayoutPerTokenScale = typename PerTokenScaleType_::Layout; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && + (std::is_same_v || std::is_same_v) && + std::is_same_v && std::is_same_v, + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + + (TileShape::COUNT + TileShape::COLUMN + TileShape::COUNT + TileShape::ROW) * sizeof(float) + + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = 
Params{}) : params(params) + { + size_t ubOffset = 0; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementScale); + ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbScaleVMTE2List[i] = eventVMTE2++; + eventUbScaleMTE2VList[i] = eventMTE2V++; + eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; + eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbScaleVMTE2List[i]); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(float); + ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubPerTokenScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(float); + ubPerTokenScaleFp32Brcb = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * BYTE_PER_BLK; + ubPerTokenMul = ubMul; + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + CATLASS_DEVICE + void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutBlockC, Callback &&callback = Callback{}) + { + if (actualBlockShapeMNK.k() == 0) { + return; + } + callback(); + + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto 
actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); + auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); + + auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; + auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); + + auto &ubScale = ubScaleList[ubListId]; + auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); + + AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); + copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); + + auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); + + auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; + auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); + + auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; + auto layoutUbPerTokenScale = + LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); + + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, + layoutGmTilePerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); + AscendC::Cast(ubScaleFp32, ubScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN); + AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + AscendC::Cast(ubPerTokenScaleFp32, ubPerTokenScale, AscendC::RoundMode::CAST_NONE, TileShape::ROW); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + + AscendC::PipeBarrier(); + tileRowBroadcastMul(ubMul, ubCFp32, ubScaleFp32); + tileBroadcastOneBlk(ubPerTokenScaleFp32Brcb, ubPerTokenScaleFp32); + AscendC::PipeBarrier(); + tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleFp32Brcb); + AscendC::PipeBarrier(); + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)]; + auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? 
(ubListId + 1) : 0; + } + } + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubCFp32; + AscendC::LocalTensor ubScaleFp32; + AscendC::LocalTensor ubMul; + AscendC::LocalTensor ubPerTokenScaleFp32; + AscendC::LocalTensor ubPerTokenScaleFp32Brcb; + AscendC::LocalTensor ubPerTokenMul; + + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; +}; + +template +class BlockEpilogue, CType_, Gemm::GemmType, + Gemm::GemmType, DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, + TileOneBlkColumnBroadcastMul_, TileCopy_, EpilogueTileSwizzle_> +{ +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementScale = float; + using LayoutScale = LayoutScale_; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && + (std::is_same_v || std::is_same_v), + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + + (TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <= + ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + CATLASS_DEVICE + Params() {}; + + 
CATLASS_DEVICE + Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + size_t ubOffset = 0; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementScale); + ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbScaleVMTE2List[i] = eventVMTE2++; + eventUbScaleMTE2VList[i] = eventMTE2V++; + eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; + eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbScaleVMTE2List[i]); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubPerTokenScaleBrcb = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * BYTE_PER_BLK; + ubPerTokenMul = ubCFp32; + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + CATLASS_DEVICE + void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutBlockC, Callback &&callback = Callback{}) + { + if (actualBlockShapeMNK.k() == 0) { + return; + } + callback(); + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = 
epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); + auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); + + auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; + auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); + + auto &ubScale = ubScaleList[ubListId]; + auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); + + AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); + copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); + + auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); + + auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; + auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); + + auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; + auto layoutUbPerTokenScale = + LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); + + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, + layoutGmTilePerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); + tileRowBroadcastMul(ubMul, ubCFp32, ubScale); + AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + tileBroadcastOneBlk(ubPerTokenScaleBrcb, ubPerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + + AscendC::PipeBarrier(); + tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleBrcb); + AscendC::PipeBarrier(); + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)]; + auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? 
(ubListId + 1) : 0; + } + } + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubCFp32; + AscendC::LocalTensor ubMul; + AscendC::LocalTensor ubPerTokenScaleBrcb; + AscendC::LocalTensor ubPerTokenMul; + + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_rescale_o_no_split_row.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_rescale_o_no_split_row.hpp new file mode 100644 index 00000000..26dfe391 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/block/block_epilogue_rescale_o_no_split_row.hpp @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_RESCALE_O_NO_SPLIT_ROW_HPP +#define CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_RESCALE_O_NO_SPLIT_ROW_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue +{ +public: + // Type aliases + using DispatchPolicy = EpilogueAtlasA2RescaleO; + using ArchTag = typename DispatchPolicy::ArchTag; + + using ElementOutput = typename OutputType_::Element; + using ElementInput = typename InputType_::Element; + using ElementUpdate = typename UpdateType_::Element; + + using LayoutOutput = typename OutputType_::Layout; + using LayoutInput = typename InputType_::Layout; + using LayoutUpdate = typename UpdateType_::Layout; + + static constexpr uint32_t HALF_ELENUM_PER_BLK = 16; + static constexpr uint32_t BLOCK_SIZE = 16; + static constexpr uint32_t HALF_ELENUM_PER_VECCALC = 128; + static constexpr uint32_t FLOAT_ELENUM_PER_VECCALC = 64; + static constexpr uint32_t HALF_ELENUM_PER_LINE = 256; + static constexpr uint32_t FLOAT_ELENUM_PER_LINE = 128; + static constexpr uint32_t MULTIPLIER = 2; + static constexpr uint32_t FLOAT_BLOCK_SIZE = 8; + static constexpr uint32_t FLOAT_VECTOR_SIZE = 64; + static constexpr uint32_t UB_UINT8_VECTOR_SIZE = 1024; + static constexpr uint32_t UB_UINT8_BLOCK_SIZE = 16384; + static constexpr uint32_t HALF_DM_UB_SIZE = 64; + static constexpr uint32_t HALF_LL_UB_SIZE = 256; + static constexpr uint32_t VECTOR_SIZE = 128; + static constexpr uint32_t NUM4 = 4; + static constexpr uint32_t MAX_UB_O_ELEM_NUM = 4096; + static constexpr uint32_t MAX_ROW_NUM_SUB_CORE = 128; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource) + { + // Allocate UB space + constexpr uint32_t LO_UB_TENSOR_OFFSET = 6 * UB_UINT8_BLOCK_SIZE; + constexpr uint32_t GO_UB_TENSOR_OFFSET = 8 * UB_UINT8_BLOCK_SIZE; + constexpr uint32_t TV_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE; + + constexpr uint32_t HM_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE + 9 * UB_UINT8_VECTOR_SIZE; + constexpr uint32_t GL_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE + 12 * UB_UINT8_VECTOR_SIZE; + constexpr uint32_t DM_UB_TENSOR_OFFSET = 10 * UB_UINT8_BLOCK_SIZE + 13 * UB_UINT8_VECTOR_SIZE; + + loUbTensor = resource.ubBuf.template GetBufferByByte(LO_UB_TENSOR_OFFSET); + dmUbTensor = resource.ubBuf.template GetBufferByByte(DM_UB_TENSOR_OFFSET); + glUbTensor = resource.ubBuf.template GetBufferByByte(GL_UB_TENSOR_OFFSET); + tvUbTensor = resource.ubBuf.template GetBufferByByte(TV_UB_TENSOR_OFFSET); + goUbTensor16 = resource.ubBuf.template GetBufferByByte(GO_UB_TENSOR_OFFSET); + goUbTensor32 = resource.ubBuf.template GetBufferByByte(GO_UB_TENSOR_OFFSET); + hmUbTensor = resource.ubBuf.template GetBufferByByte(HM_UB_TENSOR_OFFSET); + } + + CATLASS_DEVICE + ~BlockEpilogue() {} + + CATLASS_DEVICE + void SetMask(int32_t len) + { + uint64_t mask = 0; + uint64_t one = 1; + uint64_t temp = len % FLOAT_VECTOR_SIZE; + for (int64_t i = 0; i < temp; i++) { + mask |= one << i; + } + + if (len == VECTOR_SIZE) { + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } else if (len >= FLOAT_VECTOR_SIZE) { + AscendC::SetVectorMask(mask, (uint64_t)-1); + } else { + AscendC::SetVectorMask(0x0, mask); + } + } + + CATLASS_DEVICE + void CopyOToGm(AscendC::GlobalTensor gOutput, uint32_t curRowNum, uint32_t qSBlockSize, + uint32_t embed, uint32_t 
embedRound, uint32_t qNThisSubBlock, uint32_t oHiddenSize)
+    {
+        if (qNThisSubBlock == 0) {
+            AscendC::DataCopyPad(gOutput, goUbTensor16,
+                                 AscendC::DataCopyExtParams(curRowNum, embed * 2, 0, (oHiddenSize - embed) * 2, 0));
+        } else {
+            for (uint32_t qNIdx = 0; qNIdx < qNThisSubBlock; qNIdx++) {
+                AscendC::DataCopyPad(
+                    gOutput[qNIdx * embed], goUbTensor16[qNIdx * embedRound * qSBlockSize],
+                    AscendC::DataCopyExtParams(qSBlockSize, embed * 2, 0, (oHiddenSize - embed) * 2, 0));
+            }
+        }
+    }
+
+    CATLASS_DEVICE
+    void SubCoreCompute(AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gInput,
+                        const LayoutOutput &layoutOutput, const LayoutInput &layoutInput, uint32_t qNThisSubBlock,
+                        uint32_t isFirstStackTile, uint32_t isLastStackTile, uint32_t curStackTileMod)
+    {
+        uint32_t curRowNum = layoutInput.shape(0);
+        uint32_t embed = layoutInput.shape(1);
+        uint32_t embedRound = layoutInput.stride(0);
+        uint32_t curRowNumRound = RoundUp(curRowNum, FLOAT_BLOCK_SIZE);
+        uint32_t qSBlockSize = layoutOutput.shape(0);
+        uint32_t oHiddenSize = layoutOutput.shape(1);
+        uint32_t dmUbOffsetCurStackTile = curStackTileMod * MAX_ROW_NUM_SUB_CORE;
+
+        AscendC::WaitFlag(EVENT_ID3);
+        if (!isFirstStackTile) {
+            AscendC::DataCopy(loUbTensor, gInput,
+                              AscendC::DataCopyParams(1, curRowNum * embedRound / FLOAT_BLOCK_SIZE, 0, 0));
+            AscendC::SetFlag(EVENT_ID0);
+            AscendC::WaitFlag(EVENT_ID0);
+
+            AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1);
+            AscendC::Brcb(tvUbTensor.ReinterpretCast(),
+                          dmUbTensor[dmUbOffsetCurStackTile].ReinterpretCast(),
+                          curRowNumRound / FLOAT_BLOCK_SIZE, AscendC::BrcbRepeatParams(1, 8));
+            AscendC::PipeBarrier();
+            // *** go = go * dm_block
+            AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1);
+            for (uint32_t vmul_idx = 0; vmul_idx < embed / FLOAT_VECTOR_SIZE; ++vmul_idx) {
+                AscendC::Mul(goUbTensor32[vmul_idx * FLOAT_VECTOR_SIZE],
+                                                      goUbTensor32[vmul_idx * FLOAT_VECTOR_SIZE], tvUbTensor, (uint64_t)0,
+                                                      curRowNum,
+                                                      AscendC::BinaryRepeatParams(1, 1, 0, embedRound / FLOAT_BLOCK_SIZE,
+                                                                                  embedRound / FLOAT_BLOCK_SIZE, 1));
+            }
+            if (embed % FLOAT_VECTOR_SIZE > 0) {
+                SetMask(embed % FLOAT_VECTOR_SIZE);
+                AscendC::Mul(goUbTensor32[embed / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE],
+                                                      goUbTensor32[embed / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], tvUbTensor,
+                                                      (uint64_t)0, curRowNum,
+                                                      AscendC::BinaryRepeatParams(1, 1, 0, embedRound / FLOAT_BLOCK_SIZE,
+                                                                                  embedRound / FLOAT_BLOCK_SIZE, 1));
+                AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1);
+            }
+            AscendC::PipeBarrier();
+            // *** go = lo + go
+            AscendC::Add(goUbTensor32, goUbTensor32, loUbTensor, (uint64_t)0,
+                                                  (curRowNum * embedRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE,
+                                                  AscendC::BinaryRepeatParams(1, 1, 1, 8, 8, 8));
+            AscendC::PipeBarrier();
+        } else {
+            // *** go = lo
+            AscendC::DataCopy(goUbTensor32, gInput,
+                              AscendC::DataCopyParams(1, curRowNum * embedRound / FLOAT_BLOCK_SIZE, 0, 0));
+            AscendC::SetFlag(EVENT_ID0);
+            AscendC::WaitFlag(EVENT_ID0);
+        }
+        AscendC::SetFlag(EVENT_ID3);
+
+        if (isLastStackTile) {
+            // *** gl_block = expand_to_block(gl), stored in tv
+            AscendC::Brcb(tvUbTensor.ReinterpretCast(), glUbTensor.ReinterpretCast(),
+                          curRowNumRound / FLOAT_BLOCK_SIZE, AscendC::BrcbRepeatParams(1, 8));
+            AscendC::PipeBarrier();
+            // *** go = go / gl_block
+            AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1);
+            for (uint32_t vdiv_idx = 0; vdiv_idx < embed / FLOAT_VECTOR_SIZE; ++vdiv_idx) {
+                AscendC::Div(goUbTensor32[vdiv_idx * FLOAT_VECTOR_SIZE],
+                                                      goUbTensor32[vdiv_idx * FLOAT_VECTOR_SIZE], tvUbTensor, (uint64_t)0,
+                                                      curRowNum,
+                                                      
AscendC::BinaryRepeatParams(1, 1, 0, embedRound / FLOAT_BLOCK_SIZE, + embedRound / FLOAT_BLOCK_SIZE, 1)); + } + if (embed % FLOAT_VECTOR_SIZE > 0) { + SetMask(embed % FLOAT_VECTOR_SIZE); + AscendC::Div(goUbTensor32[embed / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], + goUbTensor32[embed / FLOAT_VECTOR_SIZE * FLOAT_VECTOR_SIZE], tvUbTensor, + (uint64_t)0, curRowNum, + AscendC::BinaryRepeatParams(1, 1, 0, embedRound / FLOAT_BLOCK_SIZE, + embedRound / FLOAT_BLOCK_SIZE, 1)); + AscendC::SetVectorMask((uint64_t)-1, (uint64_t)-1); + } + AscendC::PipeBarrier(); + + // *** go = castfp32to16(go) + if (std::is_same::value) { + AscendC::Cast( + goUbTensor16, goUbTensor32, AscendC::RoundMode::CAST_RINT, (uint64_t)0, + (curRowNum * embedRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 4, 8)); + } else { + AscendC::Cast( + goUbTensor16, goUbTensor32, AscendC::RoundMode::CAST_NONE, (uint64_t)0, + (curRowNum * embedRound + FLOAT_VECTOR_SIZE - 1) / FLOAT_VECTOR_SIZE, + AscendC::UnaryRepeatParams(1, 1, 4, 8)); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + // ***move O to GM + CopyOToGm(gOutput, curRowNum, qSBlockSize, embed, embedRound, qNThisSubBlock, oHiddenSize); + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gOutput, AscendC::GlobalTensor gInput, + const LayoutOutput &layoutOutput, const LayoutInput &layoutInput, GemmCoord actualBlockShape, + uint32_t qSBlockSize, uint32_t qNBlockSize, uint32_t isFirstStackTile, uint32_t isLastStackTile, + uint32_t curStackTileMod) + { + uint32_t rowNum = actualBlockShape.m(); + uint32_t embed = actualBlockShape.n(); + // uint32_t columnNumRound = layoutInput.stride(0); + + uint32_t subBlockIdx = AscendC::GetSubBlockIdx(); + uint32_t subBlockNum = AscendC::GetSubBlockNum(); + + uint32_t qNSplitSubBlock = qNBlockSize / subBlockNum; + uint32_t qNThisSubBlock = (qNBlockSize == 1) ? 0 + : (subBlockIdx == 1) ? (qNBlockSize - qNSplitSubBlock) + : qNSplitSubBlock; + uint32_t inRowSplitSubBlock = + (qNBlockSize == 1) ? (qSBlockSize / subBlockNum) : (qSBlockSize * qNSplitSubBlock); + uint32_t inRowActualThisSubBlock = (subBlockIdx == 1) ? (rowNum - inRowSplitSubBlock) : inRowSplitSubBlock; + uint32_t inRowOffsetThisSubBlock = subBlockIdx * inRowSplitSubBlock; + uint32_t outRowOffsetThisSubBlock = (qNBlockSize == 1) ? inRowOffsetThisSubBlock : 0; + uint32_t outColOffsetThisSubBlock = (qNBlockSize == 1) ? 
0 : subBlockIdx * qNSplitSubBlock * embed; + + if (inRowActualThisSubBlock > 0) { + int64_t offsetOutput = + layoutOutput.GetOffset(MatrixCoord(outRowOffsetThisSubBlock, outColOffsetThisSubBlock)); + auto gOutputThisSubBlock = gOutput[offsetOutput]; + auto layoutOutputThisSubBlock = layoutOutput; + + int64_t offsetInput = layoutInput.GetOffset(MatrixCoord(inRowOffsetThisSubBlock, 0)); + auto gInputThisSubBlock = gInput[offsetInput]; + auto layoutInputThisSubBlock = layoutInput.GetTileLayout(MatrixCoord(inRowActualThisSubBlock, embed)); + SubCoreCompute(gOutputThisSubBlock, gInputThisSubBlock, layoutOutputThisSubBlock, layoutInputThisSubBlock, + qNThisSubBlock, isFirstStackTile, isLastStackTile, curStackTileMod); + } + } + +private: + AscendC::LocalTensor loUbTensor; + AscendC::LocalTensor dmUbTensor; + AscendC::LocalTensor hmUbTensor; + AscendC::LocalTensor glUbTensor; + AscendC::LocalTensor tvUbTensor; + AscendC::LocalTensor goUbTensor16; + AscendC::LocalTensor goUbTensor32; +}; +} // namespace Catlass::Epilogue::Block + +#endif // CATLASS_EPILOGUE_BLOCK_BLOCK_EPILOGUE_RESCALE_O_NO_SPLIT_ROW_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/dispatch_policy.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/dispatch_policy.hpp new file mode 100644 index 00000000..257951c6 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/dispatch_policy.hpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_DISPATCH_POLICY_HPP +#define CATLASS_EPILOGUE_DISPATCH_POLICY_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" + +namespace Catlass::Epilogue { + +// For AtlasA2, an element wise epilogue of the form D = C + X, where X is an additional source +struct EpilogueAtlasA2ElemWiseOneSource { + using ArchTag = Arch::AtlasA2; + // Number of operands. Including C, X, and D 3 operands + static constexpr uint32_t OPERANDS_NUM = 3; +}; + +struct EpilogueAtlasA2ElemWiseNoSource { + using ArchTag = Arch::AtlasA2; + // Number of operands. 
Including C, D 2 operands + static constexpr uint32_t OPERANDS_NUM = 2; +}; + +// For AtlasA2, FA Softmax +struct EpilogueAtlasA2FASoftmax { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, FA RescaleO +struct EpilogueAtlasA2FARescaleO { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, MLA Softmax +struct EpilogueAtlasA2MLASoftmax { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, FA Infer online Softmax no mask +struct EpilogueAtlasA2OnlineSoftmax { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, FA Infer RescaleO no split row +struct EpilogueAtlasA2RescaleO { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, MLA RescaleO +struct EpilogueAtlasA2MLARescaleO { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, MLA FD RescaleO +template +struct EpilogueAtlasA2MLAFDRescaleO { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t KV_SPLIT_MAX = 64; + static constexpr uint32_t HEADS_PROCESS_MAX = 16; + static constexpr uint32_t COMPUTE_ELE_NUM = COMPUTE_ELE_NUM_; +}; + +// For AtlasA2, MLA TP1 Softmax +struct EpilogueAtlasA2MLATP1Softmax { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, MLA TP1 RescaleO +struct EpilogueAtlasA2MLATP1RescaleO { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, per token dequant +template +struct EpilogueAtlasA2PerTokenDequant { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; +}; +//////////////////////////// +/// new add +// For AtlasA2, GEMM +struct EpilogueAtlasA2Gemm { + using ArchTag = Arch::AtlasA2; +}; + +// For AtlasA2, GEMV +struct EpilogueAtlasA2Gemv { + using ArchTag = Arch::AtlasA2; +}; +/////////////////////////// +} // namespace Catlass::Epilogue + +#endif // CATLASS_EPILOGUE_DISPATCH_POLICY_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/copy_gm_to_ub.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/copy_gm_to_ub.hpp new file mode 100644 index 00000000..786124dd --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/copy_gm_to_ub.hpp @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_EPILOGUE_TILE_TILE_COPY_GM_TO_UB_HPP +#define CATLASS_EPILOGUE_TILE_TILE_COPY_GM_TO_UB_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm/gemm_type.hpp" + +namespace Catlass::Epilogue::Tile { + +template +struct CopyGm2Ub { + static_assert(DEPENDENT_FALSE, "Unsupported copy gm to ub, can not find the specialization."); +}; + +template +struct CopyGm2Ub> { + using LayoutSrc = layout::RowMajor; + using LayoutDst = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + + CATLASS_DEVICE + CopyGm2Ub() = default; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams(layoutSrc.shape(0), layoutSrc.shape(1) * sizeof(Element), + (layoutSrc.stride(0) - layoutSrc.shape(1)) * sizeof(Element), + (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK, 0); + AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); + }; +}; + +template +struct CopyGm2Ub> { + using LayoutSrc = layout::VectorLayout; + using LayoutDst = layout::VectorLayout; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + + CATLASS_DEVICE + CopyGm2Ub() = default; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + layout::VectorLayout const &layoutDst, layout::VectorLayout const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams(1, layoutSrc.shape(0) * sizeof(Element), 0, 0, 0); + AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); + }; +}; + +/// @brief This copy instruction used to copy per token scale from GM to UB. +/// Copy the scale of shape (m,1) on GM to the first column of shape (m,n) on UB, +/// and pad the first block of each row (i.e. pad to shape (m,8) when element type is float). +/// @tparam ArchTag: Architecture tag. +/// @tparam GmType: Type of data on GM. 
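
For readers skimming the patch, the doc comment above amounts to the host-side picture below. This is only an illustrative sketch of the addressing: the function name and plain vectors are ours, and the real DataCopyPad parameters, 32-byte block padding, and MTE queues are not modeled.

#include <cstddef>
#include <vector>

// Illustrative only: each per-token scale lands at the start of its row in the destination
// tile; the remainder of the first 32-byte block is pad, and the rest of the row is left
// untouched for the later one-block column broadcast.
void copyPerTokenScaleReference(const std::vector<float> &gmScale, // (m, 1) source
                                std::vector<float> &ubTile,        // (m, n) destination, row-major
                                std::size_t m, std::size_t n)
{
    for (std::size_t i = 0; i < m; ++i) {
        ubTile[i * n] = gmScale[i]; // first column of row i receives token i's scale
    }
}
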
+template +struct CopyPerTokenScale2Ub { + static_assert(std::is_same_v, + "Unsupported layout for CopyPerTokenScale2Ub."); + + using Element = typename GmType::Element; + using LayoutSrc = typename GmType::Layout; + using LayoutDst = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + + CATLASS_DEVICE + CopyPerTokenScale2Ub() = default; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams; + AscendC::DataCopyPadExtParams padParams; + + dataCopyParams.blockCount = layoutSrc.shape(0); + dataCopyParams.blockLen = layoutSrc.shape(1) * sizeof(Element); // per token scale has only one column + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; + // Pad the data to the complete block + padParams.isPad = true; + padParams.leftPadding = 0; + padParams.rightPadding = 0; + + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); + } +}; + +template +struct CopyGm2UbAligned { + static_assert(DEPENDENT_FALSE, "Unsupported copy gm to ub aligned, can not find the specialization."); +}; + +template +struct CopyGm2UbAligned> { + using LayoutSrc = layout::RowMajor; + using LayoutDst = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; + static constexpr uint32_t MAX_REPEAT = 4095; + static constexpr uint32_t STRIDE_LIMIT = 65536; + + CATLASS_DEVICE + CopyGm2UbAligned() = default; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) + { + uint32_t rows = layoutSrc.shape(0); + uint32_t cols = layoutSrc.shape(1); + uint32_t srcStride = (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; + uint32_t dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; + + if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && (layoutDst.shape(1) == layoutDst.stride(0))) { + DataCopy(dstTensor, srcTensor, rows * cols); + } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { + uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); + for (uint32_t i = 0; i < rLoops; ++i) { + uint32_t rActual = (i < rLoops - 1) ? MAX_REPEAT : rows - i * MAX_REPEAT; + AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, srcStride, dstStride); + DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], + srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], dataCopyParams); + } + } else { + for (uint32_t i = 0; i < rows; ++i) { + DataCopy(dstTensor[i * layoutDst.stride(0)], srcTensor[i * layoutSrc.stride(0)], cols); + } + } + }; +}; + +} // namespace Catlass::Epilogue::Tile + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/copy_ub_to_gm.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/copy_ub_to_gm.hpp new file mode 100644 index 00000000..0f307a1a --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/copy_ub_to_gm.hpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_TILE_TILE_COPY_UB_TO_GM_HPP +#define CATLASS_EPILOGUE_TILE_TILE_COPY_UB_TO_GM_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm/gemm_type.hpp" + +namespace Catlass::Epilogue::Tile { + +template +struct CopyUb2Gm { + static_assert(DEPENDENT_FALSE, "Unsupported copy ub to gm, can not find the specialization."); +}; + +template +struct CopyUb2Gm> { + using LayoutDst = layout::RowMajor; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + CATLASS_DEVICE + CopyUb2Gm() = default; + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams(layoutDst.shape(0), layoutDst.shape(1) * sizeof(Element), + (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_C0, + (layoutDst.stride(0) - layoutDst.shape(1)) * sizeof(Element), 0); + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams); + } +}; + +// new add vectorlayout version +template +struct CopyUb2Gm> { + using LayoutSrc = layout::VectorLayout; + using LayoutDst = layout::VectorLayout; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + + CATLASS_DEVICE + CopyUb2Gm() = default; + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + layout::VectorLayout const &layoutDst, layout::VectorLayout const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams(1, layoutDst.shape(0) * sizeof(Element), 0, 0, 0); + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams); + }; +}; + +template +struct CopyUb2GmAligned { + static_assert(DEPENDENT_FALSE, "Unsupported copy ub to gm aligned, can not find the specialization."); +}; + +template +struct CopyUb2GmAligned> { + using LayoutSrc = layout::RowMajor; + using LayoutDst = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; + static constexpr uint32_t MAX_REPEAT = 4095; + static constexpr uint32_t STRIDE_LIMIT = 65536; + + CATLASS_DEVICE + CopyUb2GmAligned() = default; + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) + { + uint32_t rows = layoutDst.shape(0); + uint32_t cols = layoutDst.shape(1); + uint32_t srcStride = (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; + uint32_t dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; + + if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && (layoutDst.shape(1) == layoutDst.stride(0))) { + DataCopy(dstTensor, srcTensor, rows * cols); + } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { + uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); + for (uint32_t i = 0; i < rLoops; ++i) { + 
uint32_t rActual = (i < rLoops - 1) ? MAX_REPEAT : rows - i * MAX_REPEAT; + AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, srcStride, dstStride); + DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], + srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], dataCopyParams); + } + } else { + for (uint32_t i = 0; i < rows; ++i) { + DataCopy(dstTensor[i * layoutDst.stride(0)], srcTensor[i * layoutSrc.stride(0)], cols); + } + } + }; +}; + +} // namespace Catlass::Epilogue::Tile + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_inplace_by_column.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_inplace_by_column.hpp new file mode 100644 index 00000000..45811ee6 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_inplace_by_column.hpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_COLUMN_HPP +#define CATLASS_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_COLUMN_HPP + +#include "catlass/catlass.hpp" + +namespace Catlass::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class ComputeType_, + /// Length of the compute buffer + class TileShape_> +struct TileBroadcastInplaceByColumn { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + CATLASS_DEVICE + TileBroadcastInplaceByColumn() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubInOut) + { + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + constexpr uint32_t blkNumPerRow = TileShape::COLUMN / eleNumPerBlk; + + constexpr uint64_t defaultMask = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + constexpr uint64_t tailMask = (TileShape::ROW % BLK_NUM_PER_VECTOR_FRACTAL) * eleNumPerBlk; + + constexpr uint8_t repeatTimes = 1; + + AscendC::CopyRepeatParams repeatParams; + repeatParams.dstStride = blkNumPerRow; + repeatParams.srcStride = blkNumPerRow; + repeatParams.dstRepeatSize = 1; + repeatParams.srcRepeatSize = 1; + + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += BLK_NUM_PER_VECTOR_FRACTAL) { + uint64_t mask = ((TileShape::ROW - rowOffset) >= BLK_NUM_PER_VECTOR_FRACTAL) ? 
defaultMask : tailMask; + for (uint32_t colOffset = eleNumPerBlk; colOffset < TileShape::COLUMN; colOffset += eleNumPerBlk) { + AscendC::Copy(ubInOut[rowOffset * TileShape::COLUMN + colOffset], + ubInOut[rowOffset * TileShape::COLUMN], mask, 1, repeatParams); + } + } + } +}; + +} // namespace Catlass::Epilogue::Tile + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_inplace_by_row.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_inplace_by_row.hpp new file mode 100644 index 00000000..6740abf1 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_inplace_by_row.hpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_ROW_HPP +#define CATLASS_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_ROW_HPP + +#include "catlass/catlass.hpp" + +namespace Catlass::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class ComputeType_, + /// Length of the compute buffer + class TileShape_> +struct TileBroadcastInplaceByRow { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + CATLASS_DEVICE + TileBroadcastInplaceByRow() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubInOut) + { + constexpr uint32_t eleNumPerVectorFractal = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + + constexpr uint64_t mask = eleNumPerVectorFractal; + constexpr uint8_t repeatTimes = TileShape::COLUMN / eleNumPerVectorFractal; + + AscendC::CopyRepeatParams repeatParams; + repeatParams.dstStride = 1; + repeatParams.srcStride = 1; + repeatParams.dstRepeatSize = BLK_NUM_PER_VECTOR_FRACTAL; + repeatParams.srcRepeatSize = BLK_NUM_PER_VECTOR_FRACTAL; + + for (uint32_t rowOffset = 1; rowOffset < TileShape::ROW; ++rowOffset) { + AscendC::Copy(ubInOut[rowOffset * TileShape::COLUMN], ubInOut, mask, repeatTimes, repeatParams); + } + } +}; + +} // namespace Catlass::Epilogue::Tile + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_mul.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_mul.hpp new file mode 100644 index 00000000..47606f08 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_mul.hpp @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_TILE_TILE_BROADCAST_MUL_HPP +#define CATLASS_EPILOGUE_TILE_TILE_BROADCAST_MUL_HPP + +#include "catlass/catlass.hpp" + +namespace Catlass::Epilogue::Tile { + +/// BroadcastMul computes the elementwise multiplication of a tensor of shape (m, n) and a tensor +/// of shape (m, n) after broadcasting. There are two broadcast modes: row-broadcast and +/// column-broadcast. + +/// @brief Computes the elementwise multiplication of a tensor with shape (m, n) and a tensor with +/// original shape (1, n) broadcast to (m, n). +/// @tparam ArchTag_ is the architecture tag. +/// @tparam ComputeType_ includes the element type and layout information. +/// @tparam TileShape_ is the shape (m, n). +template +struct TileRowBroadcastMul { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + CATLASS_DEVICE + TileRowBroadcastMul() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) + { + constexpr uint32_t maxRepeatTimes = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + constexpr uint32_t blkNumPerColumn = TileShape::COLUMN / eleNumPerBlk; + AscendC::BinaryRepeatParams repeatParams; + repeatParams.dstBlkStride = 1; + repeatParams.src0BlkStride = 1; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = blkNumPerColumn; + repeatParams.src0RepStride = blkNumPerColumn; + repeatParams.src1RepStride = 0; + + constexpr uint32_t rowNumPerCompute = maxRepeatTimes; + constexpr uint32_t colNumPerCompute = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += rowNumPerCompute) { + uint32_t residueM = TileShape::ROW - rowOffset; + uint8_t repeatTimes = static_cast((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM); + for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; colOffset += colNumPerCompute) { + uint32_t residueN = TileShape::COLUMN - colOffset; + uint64_t mask = (residueN > colNumPerCompute) ? colNumPerCompute : residueN; + AscendC::Mul(ubOut[rowOffset * TileShape::COLUMN + colOffset], + ubIn0[rowOffset * TileShape::COLUMN + colOffset], ubIn1[colOffset], mask, repeatTimes, + repeatParams); + } + } + } +}; + +/// @brief Compute the elementwise multiplication of a tensor of shape (m, n) and a tensor of shape +/// (m, eleNumPerBlk), which is broadcast from a tensor of shape (m, 1), broadcast to (m, n). +/// @tparam ArchTag_ is the architecture tag. +/// @tparam ComputeType_ includes the element type and layout information. +/// @tparam TileShape_ is the shape (m, n). 
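
Taken together with TileRowBroadcastMul above, this pair of tiles realizes the per-token dequant arithmetic used by the epilogue. The snippet below is a minimal scalar reference of that intent only: the function name and plain row-major buffers are ours, and block alignment, masks, and the repeat/stride parameters of the AscendC calls are not modeled.

#include <cstddef>
#include <vector>

// out[i][j] = c[i][j] * rowScale[j] * colScale[i]
// rowScale has logical shape (1, n) and colScale (m, 1); both are broadcast to (m, n).
void broadcastMulReference(const std::vector<float> &c,        // (m, n), row-major
                           const std::vector<float> &rowScale, // (n)
                           const std::vector<float> &colScale, // (m)
                           std::vector<float> &out,            // (m, n), row-major
                           std::size_t m, std::size_t n)
{
    for (std::size_t i = 0; i < m; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
            out[i * n + j] = c[i * n + j] * rowScale[j] * colScale[i];
        }
    }
}
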
+template +struct TileOneBlkColumnBroadcastMul { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + CATLASS_DEVICE + TileOneBlkColumnBroadcastMul() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) + { + constexpr uint32_t maxRepeatNum = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + constexpr uint32_t blkNumPerColumn = TileShape::COLUMN / eleNumPerBlk; + AscendC::BinaryRepeatParams repeatParams; + repeatParams.dstBlkStride = blkNumPerColumn; + repeatParams.src0BlkStride = blkNumPerColumn; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = 1; + repeatParams.src0RepStride = 1; + repeatParams.src1RepStride = 0; + + constexpr uint32_t rowNumPerCompute = BLK_NUM_PER_VECTOR_FRACTAL; + constexpr uint32_t colNumPerCompute = eleNumPerBlk * maxRepeatNum; + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += rowNumPerCompute) { + uint32_t residueM = TileShape::ROW - rowOffset; + uint64_t mask = ((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM) * eleNumPerBlk; + for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; colOffset += colNumPerCompute) { + uint32_t residueN = TileShape::COLUMN - colOffset; + uint8_t repeatTimes = + static_cast(((residueN > colNumPerCompute) ? colNumPerCompute : residueN) / eleNumPerBlk); + AscendC::Mul(ubOut[rowOffset * TileShape::COLUMN + colOffset], + ubIn0[rowOffset * TileShape::COLUMN + colOffset], ubIn1[rowOffset * eleNumPerBlk], mask, + repeatTimes, repeatParams); + } + } + } +}; + +} // namespace Catlass::Epilogue::Tile + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_one_blk.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_one_blk.hpp new file mode 100644 index 00000000..923f9e29 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_broadcast_one_blk.hpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_EPILOGUE_TILE_TILE_BROADCAST_ONE_BLK_HPP +#define CATLASS_EPILOGUE_TILE_TILE_BROADCAST_ONE_BLK_HPP + +#include "catlass/catlass.hpp" + +namespace Catlass::Epilogue::Tile { + +template +struct TileBroadcastOneBlk { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; + + CATLASS_DEVICE + TileBroadcastOneBlk() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, AscendC::LocalTensor const &ubIn) + { + constexpr uint32_t maxRepeatNum = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + AscendC::BrcbRepeatParams repeatParams; + repeatParams.dstBlkStride = 1; + repeatParams.dstRepStride = BLK_NUM_PER_VECTOR_FRACTAL; + + constexpr uint32_t eleNumPerCompute = RoundDown(maxRepeatNum * BLK_NUM_PER_VECTOR_FRACTAL); + for (uint32_t offset = 0; offset < COMPUTE_LENGTH; offset += eleNumPerCompute) { + uint32_t residueM = COMPUTE_LENGTH - offset; + uint32_t computeM = (residueM > eleNumPerCompute) ? eleNumPerCompute : residueM; + uint8_t repeatTimes = static_cast(CeilDiv(computeM)); + AscendC::Brcb(ubOut[offset * eleNumPerBlk], ubIn[offset], repeatTimes, repeatParams); + } + } +}; + +} // namespace Catlass::Epilogue::Tile + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_cast.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_cast.hpp new file mode 100644 index 00000000..2d3b1ccb --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_cast.hpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_TILE_TILE_CAST_HPP +#define CATLASS_EPILOGUE_TILE_TILE_CAST_HPP + +#include "catlass/catlass.hpp" + +namespace Catlass::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class DstType_, class SrcType_, + /// Length of the compute buffer + class TileShape_> +struct TileCast { + using ArchTag = ArchTag_; + using ElementDst = typename DstType_::Element; + using ElementSrc = typename SrcType_::Element; + using TileShape = TileShape_; + + CATLASS_DEVICE + TileCast() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, AscendC::LocalTensor const &ubIn) + { + AscendC::Cast(ubOut, ubIn, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + } +}; + +} // namespace Catlass::Epilogue::Tile + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_copy.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_copy.hpp new file mode 100644 index 00000000..a553c6a5 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_copy.hpp @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_TILE_TILE_COPY_HPP +#define CATLASS_EPILOGUE_TILE_TILE_COPY_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/epilogue/tile/copy_gm_to_ub.hpp" +#include "catlass/epilogue/tile/copy_ub_to_gm.hpp" + +namespace Catlass::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag, class... Args> +struct TileCopy { + static_assert(DEPENDENT_FALSE, "Unsupported tile copy, can not find the specialization."); +}; + +template +struct TileCopy { + using ElementC = typename CType::Element; + using ElementD = typename DType::Element; + + using CopyGmToUbC = CopyGm2Ub; + using CopyUbToGmD = CopyUb2Gm; +}; + +template +struct TileCopy { + using ElementC = typename CType::Element; + using ElementX = typename XType::Element; + using ElementD = typename DType::Element; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbX = CopyGm2Ub; + using CopyUbToGmD = CopyUb2Gm; +}; + +template +struct TileCopy { + using ElementC = typename CType::Element; + using ElementX = typename XType::Element; + using ElementY = typename YType::Element; + using ElementD = typename DType::Element; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbX = CopyGm2Ub; + using CopyGmToUbY = CopyGm2Ub; + using CopyUbToGmD = CopyUb2Gm; +}; + +template +struct TileCopyBf16 { + using ElementC = typename CType::Element; + using ElementX = bfloat16_t; + using ElementY = bfloat16_t; + using ElementD = bfloat16_t; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbX = CopyGm2Ub>; + using CopyGmToUbY = CopyGm2Ub>; + using CopyUbToGmD = CopyUb2Gm>; +}; + +template +struct TileCopyPerTokenDequant { + using ElementC = typename CType::Element; + using ElementScale = typename ScaleType::Element; + using ElementPerTokenScale = typename PerTokenScaleType::Element; + using ElementD = typename DType::Element; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbScale = CopyGm2Ub; + using CopyGmToUbPerTokenScale = CopyPerTokenScale2Ub; + using CopyUbToGmD = CopyUb2Gm; +}; +} // namespace Catlass::Epilogue::Tile + +#endif // CATLASS_EPILOGUE_TILE_TILE_COPY_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_add.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_add.hpp new file mode 100644 index 00000000..89a1341f --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_add.hpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_TILE_TILE_ELEMWISE_ADD_HPP +#define CATLASS_EPILOGUE_TILE_TILE_ELEMWISE_ADD_HPP + +#include "catlass/catlass.hpp" + +namespace Catlass::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class ComputeType_, + /// Length of the compute buffer + uint32_t COMPUTE_LENGTH_> +struct TileElemWiseAdd { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + + static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; + + CATLASS_DEVICE + TileElemWiseAdd() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) + { + // Do the calculation + AscendC::Add(ubOut, ubIn0, ubIn1, COMPUTE_LENGTH); + } +}; + +} // namespace Catlass::Epilogue::Tile + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_gelu.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_gelu.hpp new file mode 100644 index 00000000..4b86b2af --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_gelu.hpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+
+#ifndef CATLASS_EPILOGUE_TILE_TILE_ELEMWISE_GELU_HPP
+#define CATLASS_EPILOGUE_TILE_TILE_ELEMWISE_GELU_HPP
+
+#include "catlass/catlass.hpp"
+
+namespace Catlass::Epilogue::Tile {
+template <
+    /// Tag indicating architecture
+    class ArchTag_,
+    /// Compute data type
+    class ComputeType_,
+    /// Length of the compute buffer
+    uint32_t COMPUTE_LENGTH_>
+struct TileElemWiseGelu {
+    using ArchTag = ArchTag_;
+    using ElementCompute = typename ComputeType_::Element;
+
+    static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_;
+    const float NEG_SQRT_EIGHT_OVER_PI = -1.595769121 * 0.044715;
+    const float TANH_APPROX_FACTOR = 1 / 0.044715;
+
+    CATLASS_DEVICE
+    TileElemWiseGelu() {}
+
+    CATLASS_DEVICE
+    void operator()(AscendC::LocalTensor<ElementCompute> const &dstLocal,
+                    AscendC::LocalTensor<ElementCompute> const &srcLocal)
+    {
+        using namespace AscendC;
+
+        // current realization: x / (1 + e^(-1.5957691*0.044715*(x/0.044715 + x^3)))
+        Mul(dstLocal, srcLocal, srcLocal, COMPUTE_LENGTH);  // d: x^2, s: x
+        Mul(dstLocal, dstLocal, srcLocal, COMPUTE_LENGTH);  // d: x^3, s: x
+        Axpy(dstLocal, srcLocal, TANH_APPROX_FACTOR, COMPUTE_LENGTH);  // d: x / 0.044715 + x^3, s: x
+        // d: -1.5957691*0.044715*(x/0.044715 + x^3), s: x
+        Muls(dstLocal, dstLocal, NEG_SQRT_EIGHT_OVER_PI, COMPUTE_LENGTH);
+        Exp(dstLocal, dstLocal, COMPUTE_LENGTH);  // d: e^(-1.5957691*0.044715*(x/0.044715 + x^3))
+        // d: 1 + e^(-1.5957691*0.044715*(x/0.044715 + x^3)), s: x
+        Adds(dstLocal, dstLocal, (ElementCompute)1, COMPUTE_LENGTH);
+        Div(dstLocal, srcLocal, dstLocal, COMPUTE_LENGTH);
+    }
+};
+} // namespace Catlass::Epilogue::Tile
+
+#endif
diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_mul.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_mul.hpp
new file mode 100644
index 00000000..08ad9862
--- /dev/null
+++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_mul.hpp
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#ifndef CATLASS_EPILOGUE_TILE_TILE_ELEMWISE_MUL_HPP +#define CATLASS_EPILOGUE_TILE_TILE_ELEMWISE_MUL_HPP + +#include "catlass/catlass.hpp" + +namespace Catlass::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class ComputeType_, + /// Length of the compute buffer + class TileShape_> +struct TileElemwiseMul { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + CATLASS_DEVICE + TileElemwiseMul() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) + { + // Do the calculation + AscendC::Mul(ubOut, ubIn0, ubIn1, TileShape::COUNT); + } +}; + +} // namespace Catlass::Epilogue::Tile + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_muls.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_muls.hpp new file mode 100644 index 00000000..b1b5451e --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_muls.hpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP +#define CATLASS_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP + +#include "catlass/gemm/helper.hpp" + +namespace Catlass::Epilogue::Tile { +template +struct TileElemWiseMuls { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + + static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; + + CATLASS_DEVICE + TileElemWiseMuls() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstLocal, AscendC::LocalTensor srcTensor, + ElementCompute scalar) + { + AscendC::Muls(dstLocal, srcTensor, scalar, COMPUTE_LENGTH); + } +}; +} // namespace Catlass::Epilogue::Tile + +#endif // CATLASS_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_swish.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_swish.hpp new file mode 100644 index 00000000..fc00ef79 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_elemwise_swish.hpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_EPILOGUE_TILE_TILE_ELEMWISE_SWISH_HPP +#define CATLASS_EPILOGUE_TILE_TILE_ELEMWISE_SWISH_HPP + +#include "catlass/catlass.hpp" + +namespace Catlass::Epilogue::Tile { +template < + // / Tag indicating architecture + class ArchTag_, + // / Compute data type + class ComputeType_, + // / COMPUTE_LENGTH of the compute buffer + uint32_t COMPUTE_COMPUTE_LENGTH_> +struct TileElemWiseSwish { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + + static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_COMPUTE_LENGTH_; + + CATLASS_DEVICE + TileElemWiseSwish() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstLocal, + AscendC::LocalTensor const &srcLocal) + { + using namespace AscendC; + // d: -x, s: x + Muls(dstLocal, srcLocal, (ElementCompute)-1, COMPUTE_LENGTH); + // d: exp(-x), s: x + Exp(dstLocal, dstLocal, COMPUTE_LENGTH); + // d: 1 + exp(-x), s: x + Adds(dstLocal, dstLocal, (ElementCompute)1, COMPUTE_LENGTH); + // d: x / 1 + exp(-x), s: x + Div(dstLocal, srcLocal, dstLocal, COMPUTE_LENGTH); + } +}; +} // namespace Catlass::Epilogue::Tile + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_swizzle.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_swizzle.hpp new file mode 100644 index 00000000..b07a176d --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/epilogue/tile/tile_swizzle.hpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_EPILOGUE_TILE_TILE_SWIZZLE_HPP +#define CATLASS_EPILOGUE_TILE_TILE_SWIZZLE_HPP + +#include "catlass/catlass.hpp" +#include "catlass/detail/alignment.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Epilogue::Tile { + +struct EpilogueIdentityTileSwizzle { + MatrixCoord blockShape; + MatrixCoord tileShape; + MatrixCoord loopsMN; + + CATLASS_DEVICE + EpilogueIdentityTileSwizzle() = default; + + CATLASS_DEVICE + EpilogueIdentityTileSwizzle(MatrixCoord const &blockShape, MatrixCoord const &tileShape) + : blockShape(blockShape), tileShape(tileShape) + { + loopsMN = CeilDiv(blockShape, tileShape); + } + + CATLASS_DEVICE + uint32_t GetLoops() const + { + return loopsMN.row() * loopsMN.column(); + } + + CATLASS_DEVICE + MatrixCoord GetTileCoord(uint32_t loopIdx) const + { + return MatrixCoord{loopIdx / loopsMN.column(), loopIdx % loopsMN.column()}; + } + + CATLASS_DEVICE + MatrixCoord GetActualTileShape(MatrixCoord const &tileCoord) const + { + return MatrixCoord::Min(tileShape, blockShape - tileCoord * tileShape); + } +}; + +struct EpilogueHorizontalTileSwizzle { + MatrixCoord blockShape; + MatrixCoord tileShape; + MatrixCoord loopsMN; + + CATLASS_DEVICE + EpilogueHorizontalTileSwizzle() = default; + + CATLASS_DEVICE + EpilogueHorizontalTileSwizzle(MatrixCoord const &blockShape, MatrixCoord const &tileShape) + : blockShape(blockShape), tileShape(tileShape) + { + loopsMN = CeilDiv(blockShape, tileShape); + } + + CATLASS_DEVICE + uint32_t GetLoops() const + { + return loopsMN.row() * loopsMN.column(); + } + + CATLASS_DEVICE + MatrixCoord GetTileCoord(uint32_t loopIdx) const + { + return MatrixCoord{loopIdx % loopsMN.row(), loopIdx / loopsMN.row()}; + } + + CATLASS_DEVICE + MatrixCoord GetActualTileShape(MatrixCoord const &tileCoord) const + { + return MatrixCoord::Min(tileShape, blockShape - tileCoord * tileShape); + } +}; + +} // namespace Catlass::Epilogue::Tile + +#endif // CATLASS_EPILOGUE_TILE_TILE_SWIZZLE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_dequant.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_dequant.hpp new file mode 100644 index 00000000..c23fdd94 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_dequant.hpp @@ -0,0 +1,346 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_BLOCK_DEQUANT_HPP +#define CATLASS_GEMM_BLOCK_DEQUANT_HPP + +#include "catlass/catlass.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/epilogue/tile/copy_gm_to_ub.hpp" +#include "catlass/epilogue/tile/copy_ub_to_gm.hpp" + +namespace Catlass::Gemm::Block { + +template +struct DequantFP8toFP16 { +public: + using ArchTag = ArchTag_; + using ElementIn = Element_; + using LayoutIn = LayoutIn_; + using CopyGm2Ub = Catlass::Epilogue::Tile::CopyGm2Ub>; + using CopyUb2Gm = Catlass::Epilogue::Tile::CopyUb2Gm>; + + using CopyGm2UbFP32 = Catlass::Epilogue::Tile::CopyGm2Ub>; + using CopyUb2GmFP32 = Catlass::Epilogue::Tile::CopyUb2Gm>; + using LayoutC = Catlass::layout::RowMajor; + + static const uint32_t Alignment = 256; + + CopyGm2Ub copyGm2Ub; + CopyUb2Gm copyUb2Gm; + CopyGm2UbFP32 copyGm2UbFP32; + CopyUb2GmFP32 copyUb2GmFP32; + + struct AiCoreInfo { + uint32_t AivNum; + uint32_t AivId; + } aiCoreInfo; + + struct BlockLoopInfo { + uint32_t m; + uint32_t n; + uint32_t aivId; + uint32_t aivNum; + uint64_t srcBlockOffset; + uint64_t dstBlockOffset; + uint32_t totalLoop; + uint32_t nLoop; + uint32_t taskPerAiv; + }; + + struct LoadStoreInfo { + uint32_t loopIdx; + uint32_t mIdx; + uint32_t nIdx; + uint64_t srcProcessOffset; + uint64_t dstProcessOffset; + // loader params + uint32_t loadRepeat = 1; + uint32_t loadLen; + uint32_t srcLoadStride = 0; + uint32_t dstLoadStride = 0; + // storer params + uint32_t storeRepeat = 1; + uint32_t storeLen; + uint32_t srcStoreStride = 0; + uint32_t dstStoreStride; + }; + + CATLASS_DEVICE + DequantFP8toFP16() {} + + CATLASS_DEVICE + DequantFP8toFP16(Arch::Resource &resource) + { + int64_t bufferOffset = 0; + for (uint32_t i = 0; i < BUFFER_NUM; i++) { + inputBuffer[i] = resource.ubBuf.template GetBufferByByte(bufferOffset * sizeof(ElementIn)); + bufferOffset += COMPUTE_LENGTH; + } + for (uint32_t i = 0; i < BUFFER_NUM; i++) { + outputBuffer[i] = (resource.ubBuf.template GetBufferByByte(bufferOffset * sizeof(ElementIn))) + .template ReinterpretCast(); + bufferOffset += COMPUTE_LENGTH * 2; + } + for (uint32_t i = 0; i < BUFFER_NUM; i++) { + workspace[i] = (resource.ubBuf.template GetBufferByByte(bufferOffset * sizeof(ElementIn))) + .template ReinterpretCast(); + bufferOffset += COMPUTE_LENGTH * 2; + } + int16_t value_uint = 0x4000; + value_vector1 = (resource.ubBuf.template GetBufferByByte(bufferOffset * sizeof(ElementIn))) + .template ReinterpretCast(); + bufferOffset += 256; + AscendC::Duplicate(value_vector1, value_uint, 128); + pipe_barrier(PIPE_V); + value_uint = 0x3FFF; + value_vector2 = (resource.ubBuf.template GetBufferByByte(bufferOffset * sizeof(ElementIn))) + .template ReinterpretCast(); + bufferOffset += 256; + AscendC::Duplicate(value_vector2, value_uint, 128); + pipe_barrier(PIPE_V); + } + + CATLASS_DEVICE + void GetBlockLoopInfo(BlockLoopInfo &blockLoopInfo, uint32_t srcStride, uint32_t dstStride) + { + blockLoopInfo.taskPerAiv = blockLoopInfo.m / blockLoopInfo.aivNum; + uint32_t taskRemain = blockLoopInfo.m % blockLoopInfo.aivNum; + if (blockLoopInfo.aivId < taskRemain) { + blockLoopInfo.taskPerAiv++; + } + + uint32_t alignedN = RoundUp(blockLoopInfo.n); + blockLoopInfo.srcBlockOffset = blockLoopInfo.aivId * blockLoopInfo.taskPerAiv * srcStride; + blockLoopInfo.dstBlockOffset = blockLoopInfo.aivId * blockLoopInfo.taskPerAiv * dstStride; + if (blockLoopInfo.aivId >= 
taskRemain) { + blockLoopInfo.srcBlockOffset += taskRemain * srcStride; + blockLoopInfo.dstBlockOffset += taskRemain * dstStride; + } + if (alignedN > COMPUTE_LENGTH / 2) { + blockLoopInfo.nLoop = (blockLoopInfo.n + COMPUTE_LENGTH - 1) / COMPUTE_LENGTH; + blockLoopInfo.totalLoop = blockLoopInfo.taskPerAiv * blockLoopInfo.nLoop; + } else if (alignedN != 0) { + blockLoopInfo.nLoop = COMPUTE_LENGTH / alignedN; + blockLoopInfo.totalLoop = (blockLoopInfo.taskPerAiv + blockLoopInfo.nLoop - 1) / blockLoopInfo.nLoop; + } else { + blockLoopInfo.nLoop = 0; + blockLoopInfo.totalLoop = 0; + } + } + + CATLASS_DEVICE + void GetLoaderStorerInfo(BlockLoopInfo &blockLoopInfo, LoadStoreInfo &loadStoreInfo, uint32_t srcStride, + uint32_t dstStride) + { + loadStoreInfo.loadLen = COMPUTE_LENGTH; + uint32_t alignedN = RoundUp(blockLoopInfo.n); + if (alignedN > COMPUTE_LENGTH / 2) { + loadStoreInfo.mIdx = loadStoreInfo.loopIdx / blockLoopInfo.nLoop; + loadStoreInfo.nIdx = loadStoreInfo.loopIdx % blockLoopInfo.nLoop; + + loadStoreInfo.srcProcessOffset = + blockLoopInfo.srcBlockOffset + loadStoreInfo.mIdx * srcStride + loadStoreInfo.nIdx * COMPUTE_LENGTH; + loadStoreInfo.dstProcessOffset = + blockLoopInfo.dstBlockOffset + loadStoreInfo.mIdx * dstStride + loadStoreInfo.nIdx * COMPUTE_LENGTH; + if ((loadStoreInfo.nIdx == blockLoopInfo.nLoop - 1) && (blockLoopInfo.n % COMPUTE_LENGTH != 0)) { + loadStoreInfo.loadLen = blockLoopInfo.n % COMPUTE_LENGTH; + } + loadStoreInfo.storeLen = loadStoreInfo.loadLen; + } else { + loadStoreInfo.mIdx = loadStoreInfo.loopIdx * blockLoopInfo.nLoop; + loadStoreInfo.srcProcessOffset = blockLoopInfo.srcBlockOffset + loadStoreInfo.mIdx * srcStride; + loadStoreInfo.dstProcessOffset = blockLoopInfo.dstBlockOffset + loadStoreInfo.mIdx * dstStride; + loadStoreInfo.loadLen = blockLoopInfo.n; + loadStoreInfo.loadRepeat = blockLoopInfo.nLoop; + loadStoreInfo.storeLen = blockLoopInfo.n; + loadStoreInfo.storeRepeat = blockLoopInfo.nLoop; + loadStoreInfo.dstStoreStride = dstStride; + if ((loadStoreInfo.loopIdx == blockLoopInfo.totalLoop - 1) && + (blockLoopInfo.taskPerAiv % blockLoopInfo.nLoop != 0)) { + loadStoreInfo.storeRepeat = blockLoopInfo.taskPerAiv % blockLoopInfo.nLoop; + loadStoreInfo.loadRepeat = loadStoreInfo.storeRepeat; + } + } + loadStoreInfo.srcLoadStride = srcStride; + loadStoreInfo.dstLoadStride = alignedN; + loadStoreInfo.srcStoreStride = alignedN; + } + + CATLASS_DEVICE + void Dequant(AscendC::LocalTensor &src, AscendC::LocalTensor &dst, + AscendC::LocalTensor &value_vector1, AscendC::LocalTensor &value_vector2, + AscendC::LocalTensor &workspace, half scalar, half zeroPoint) + { + pipe_barrier(PIPE_V); + uint32_t num = COMPUTE_LENGTH; + num = (num + 128 - 1) / 128 * 128; + AscendC::Cast(dst.template ReinterpretCast(), src.template ReinterpretCast(), + AscendC::RoundMode::CAST_NONE, num); + pipe_barrier(PIPE_V); + + AscendC::Adds(dst, dst, 1024, num); + pipe_barrier(PIPE_V); + + AscendC::ShiftLeft(dst.template ReinterpretCast(), dst.template ReinterpretCast(), + 7, num); + pipe_barrier(PIPE_V); + + uint64_t mask = 128; + AscendC::And(workspace.template ReinterpretCast(), dst.template ReinterpretCast(), + value_vector1, mask, num / 128, {1, 1, 1, 8, 8, 0}); + pipe_barrier(PIPE_V); + + AscendC::ShiftLeft(workspace.template ReinterpretCast(), + workspace.template ReinterpretCast(), 1, num); + pipe_barrier(PIPE_V); + + AscendC::And(dst.template ReinterpretCast(), dst.template ReinterpretCast(), + value_vector2, mask, num / 128, {1, 1, 1, 8, 8, 0}); + pipe_barrier(PIPE_V); + + 
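+        // The two masked bit fields are recombined below (Or), the assembled pattern is
+        // moved into the high byte of every 16-bit lane (the Muls by 1 << 8 on the
+        // reinterpreted integer view acts as a left shift by 8), and the value is then
+        // finished in fp16 by adding the zero point and multiplying by the dequant scale.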
AscendC::Or(dst.template ReinterpretCast(), dst.template ReinterpretCast(), + workspace.template ReinterpretCast(), num); + pipe_barrier(PIPE_V); + + AscendC::Muls(dst.template ReinterpretCast(), dst.template ReinterpretCast(), 1 << 8, num); + pipe_barrier(PIPE_V); + + AscendC::Adds(dst, dst, zeroPoint, num); + pipe_barrier(PIPE_V); + + AscendC::Muls(dst, dst, scalar, num); + pipe_barrier(PIPE_V); + } + + CATLASS_DEVICE + void castFP32toFP16(AscendC::GlobalTensor src, AscendC::GlobalTensor dst, LayoutC layout, + uint32_t srcStride, uint32_t dstStride) + { + AscendC::LocalTensor input[BUFFER_NUM]; + AscendC::LocalTensor output[BUFFER_NUM]; + + Arch::Resource resource; + int64_t bufferOffset = 0; + const int64_t CAST_LENGTH = 32 * 1024 / sizeof(half); // 一次处理16K个数据 + for (int i = 0; i < BUFFER_NUM; i++) { + input[i] = resource.ubBuf.template GetBufferByByte(bufferOffset); + bufferOffset += CAST_LENGTH * 4; // float 4字节 + } + for (int i = 0; i < BUFFER_NUM; i++) { + output[i] = resource.ubBuf.template GetBufferByByte(bufferOffset); + bufferOffset += CAST_LENGTH * 2; // half 2字节 + } + + BlockLoopInfo blockLoopInfo; + blockLoopInfo.m = layout.shape(0); + blockLoopInfo.n = layout.shape(1); + + blockLoopInfo.aivNum = 2; + blockLoopInfo.aivId = AscendC::GetSubBlockIdx(); + + GetBlockLoopInfo(blockLoopInfo, srcStride, dstStride); + for (int ldx = 0; ldx < blockLoopInfo.totalLoop; ldx++) { + LoadStoreInfo loadStoreInfo; + loadStoreInfo.loopIdx = ldx; + GetLoaderStorerInfo(blockLoopInfo, loadStoreInfo, srcStride, dstStride); + AscendC::WaitFlag(EventIdBufferForCast[bufferIndexForCast]); + + auto layoutSrcIn = + layout::RowMajor(loadStoreInfo.loadRepeat, loadStoreInfo.loadLen, loadStoreInfo.srcLoadStride); + auto layoutDstIn = + layout::RowMajor(loadStoreInfo.loadRepeat, loadStoreInfo.loadLen, loadStoreInfo.dstLoadStride); + copyGm2UbFP32(input[bufferIndexForCast], src[loadStoreInfo.srcProcessOffset], layoutDstIn, layoutSrcIn); + + AscendC::SetFlag(EventIdBufferForCast[bufferIndexForCast]); + AscendC::WaitFlag(EventIdBufferForCast[bufferIndexForCast]); + AscendC::WaitFlag(EventIdBufferForCast[bufferIndexForCast]); + AscendC::Cast(output[bufferIndexForCast], input[bufferIndexForCast], AscendC::RoundMode::CAST_RINT, + CAST_LENGTH); + AscendC::SetFlag(EventIdBufferForCast[bufferIndexForCast]); + AscendC::WaitFlag(EventIdBufferForCast[bufferIndexForCast]); + AscendC::SetFlag(EventIdBufferForCast[bufferIndexForCast]); + + auto layoutSrcOut = + layout::RowMajor(loadStoreInfo.storeRepeat, loadStoreInfo.storeLen, loadStoreInfo.srcStoreStride); + auto layoutDstOut = + layout::RowMajor(loadStoreInfo.storeRepeat, loadStoreInfo.storeLen, loadStoreInfo.dstStoreStride); + copyUb2Gm(dst[loadStoreInfo.dstProcessOffset], output[bufferIndexForCast], layoutDstOut, layoutSrcOut); + AscendC::SetFlag(EventIdBufferForCast[bufferIndexForCast]); + bufferIndexForCast = (bufferIndexForCast + 1) % BUFFER_NUM; + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor src, AscendC::GlobalTensor dst, LayoutIn layout, + uint32_t srcStride, uint32_t dstStride, half scalar, half zeroPoint, uint32_t &bufferIndex) + { + BlockLoopInfo blockLoopInfo; + blockLoopInfo.m = layout.shape(0); + blockLoopInfo.n = layout.shape(1); + if (std::is_same::value) { + blockLoopInfo.m = layout.shape(1); + blockLoopInfo.n = layout.shape(0); + } + + blockLoopInfo.aivNum = 2; + blockLoopInfo.aivId = AscendC::GetSubBlockIdx(); + + GetBlockLoopInfo(blockLoopInfo, srcStride, dstStride); + for (int ldx = 0; ldx < blockLoopInfo.totalLoop; ldx++) { + 
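+            // Each loop iteration is one step of a double-buffered pipeline: copy a tile
+            // from GM into inputBuffer[bufferIndex], run the bit-level Dequant into
+            // outputBuffer[bufferIndex], and copy the dequantized fp16 tile back to GM;
+            // the SetFlag/WaitFlag pairs serialize reuse of the two ping-pong UB buffers
+            // so the copy of one tile can overlap the compute of the other.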
LoadStoreInfo loadStoreInfo; + loadStoreInfo.loopIdx = ldx; + GetLoaderStorerInfo(blockLoopInfo, loadStoreInfo, srcStride, dstStride); + AscendC::WaitFlag(EventIdBuffer[bufferIndex]); + + auto layoutSrcIn = + layout::RowMajor(loadStoreInfo.loadRepeat, loadStoreInfo.loadLen, loadStoreInfo.srcLoadStride); + auto layoutDstIn = + layout::RowMajor(loadStoreInfo.loadRepeat, loadStoreInfo.loadLen, loadStoreInfo.dstLoadStride); + copyGm2Ub(inputBuffer[bufferIndex], src[loadStoreInfo.srcProcessOffset], layoutDstIn, layoutSrcIn); + + AscendC::SetFlag(EventIdBuffer[bufferIndex]); + AscendC::WaitFlag(EventIdBuffer[bufferIndex]); + AscendC::WaitFlag(EventIdBuffer[bufferIndex]); + Dequant(inputBuffer[bufferIndex], outputBuffer[bufferIndex], value_vector1, value_vector2, + workspace[bufferIndex], scalar, zeroPoint); + AscendC::SetFlag(EventIdBuffer[bufferIndex]); + AscendC::WaitFlag(EventIdBuffer[bufferIndex]); + AscendC::SetFlag(EventIdBuffer[bufferIndex]); + + auto layoutSrcOut = + layout::RowMajor(loadStoreInfo.storeRepeat, loadStoreInfo.storeLen, loadStoreInfo.srcStoreStride); + auto layoutDstOut = + layout::RowMajor(loadStoreInfo.storeRepeat, loadStoreInfo.storeLen, loadStoreInfo.dstStoreStride); + copyUb2Gm(dst[loadStoreInfo.dstProcessOffset], outputBuffer[bufferIndex], layoutDstOut, layoutSrcOut); + AscendC::SetFlag(EventIdBuffer[bufferIndex]); + bufferIndex = (bufferIndex + 1) % BUFFER_NUM; + } + } + +private: + static const uint32_t BUFFER_NUM = 2; + AscendC::LocalTensor inputBuffer[BUFFER_NUM]; + AscendC::LocalTensor value_vector1; + AscendC::LocalTensor value_vector2; + AscendC::LocalTensor outputBuffer[BUFFER_NUM]; + AscendC::LocalTensor workspace[BUFFER_NUM]; + AscendC::TEventID EventIdBuffer[BUFFER_NUM] = {EVENT_ID0, EVENT_ID1}; + AscendC::TEventID EventIdBufferForCast[BUFFER_NUM] = {EVENT_ID2, EVENT_ID3}; + uint32_t bufferIndexForCast{0}; +}; + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_DEQUANT_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad.hpp new file mode 100644 index 00000000..7e7b7910 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad.hpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_HPP + +#include "catlass/catlass.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +namespace Catlass::Gemm::Block { + +template , + class TileMmad = Gemm::Tile::TileMmad > +struct BlockMmad { + static_assert(DEPENDENT_FALSE, "BlockMmad is not implemented for this DispatchPolicy"); +}; + +template , + class TileMmad = Gemm::Tile::TileMmadTla > +struct BlockMmadTla { + static_assert(DEPENDENT_FALSE, "BlockMmadTla is not implemented for this DispatchPolicy"); +}; + +/// new add for the reason that i am using the dispatchpolicy which is same as the policy of the optimized_matmul +// so i add a new one class to avoid the conflict +template , // change the name + class TileMmad = Gemm::Tile::TileMmad > +struct BlockGemm { + static_assert(DEPENDENT_FALSE, "BlockMmad is not implemented for this DispatchPolicy"); +}; + +} // namespace Catlass::Gemm::Block + +#include "catlass/gemm/block/block_mmad_pingpong.hpp" +#include "catlass/gemm/block/block_mmad_fa_qk.hpp" +#include "catlass/gemm/block/block_mmad_fa_pv.hpp" +#include "catlass/gemm/block/block_mmad_mla_qk.hpp" +#include "catlass/gemm/block/block_mmad_mla_pv.hpp" +#include "catlass/gemm/block/block_mmad_mla_qk_tp1_spec.hpp" +#include "catlass/gemm/block/block_mmad_mla_pv_tp1_spec.hpp" +#include "catlass/gemm/block/block_mmad_preload.hpp" +#include "catlass/gemm/block/block_mmad_preload_async.hpp" +#include "catlass/gemm/block/block_mmad_pingpong_tla.hpp" +#include "catlass/gemm/block/block_mmad_preload_tla.hpp" +#include "catlass/gemm/block/block_mmad_preload_async_with_callback.hpp" +#include "catlass/gemm/block/block_mmad_gemm.hpp" +#include "catlass/gemm/block/block_mmad_pingpong_bias.hpp" +#include "catlass/gemm/block/block_mmad_fai_qk_head_tail.hpp" +#include "catlass/gemm/block/block_mmad_fai_qk_normal.hpp" +#include "catlass/gemm/block/block_mmad_fai_pv_head_tail.hpp" +#include "catlass/gemm/block/block_mmad_fai_pv_normal.hpp" +#include "catlass/gemm/block/block_mmad_pingpong_full_loadA.hpp" +#include "catlass/gemm/block/block_mmad_w8a16.hpp" +#include "catlass/gemm/block/block_mmad_pingpong_slice_k.hpp" +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fa_pv.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fa_pv.hpp new file mode 100644 index 00000000..f3686746 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fa_pv.hpp @@ -0,0 +1,173 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_FA_PV_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_FA_PV_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +//////////////////////////////////////////////////////////////////// + +namespace Catlass::Gemm::Block { +//////////////////////////////////////////////////////////////////// + +template +struct BlockMmad { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2FAPV; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + static constexpr uint32_t L0C_PINGPONG_BUF_SIZE = L0C_SIZE / STAGES; + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + /// Construct + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + for (uint32_t i = 0; i < STAGES; i++) { + l1ATensor[i] = resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE * i); + l1BTensor[i] = + resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE * STAGES + L1B_SIZE * i); + l0ATensor[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensor[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + l0CTensor[i] = resource.l0CBuf.template GetBufferByByte(L0C_PINGPONG_BUF_SIZE * i); + } + + AscendC::SetFlag(EVENT_ID2); + AscendC::SetFlag(EVENT_ID3); + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() + { + AscendC::WaitFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID3); + } + + /// Perform a block-scoped matrix multiply-accumulate + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gA, 
AscendC::GlobalTensor gB, + AscendC::GlobalTensor gC, LayoutA layoutA, LayoutB layoutB, LayoutC layoutC, + GemmCoord actualShape, uint32_t &pingpongFlag, Arch::CrossCoreFlag softmaxReady) + { + LayoutAInL1 layoutAInL1 = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + LayoutBInL1 layoutBInL1 = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + uint32_t kRound = RoundUp(actualShape.k()); + + constexpr uint32_t PINGPONG_FLAG_OFFSET = 2; + AscendC::WaitFlag(pingpongFlag + PINGPONG_FLAG_OFFSET); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(actualShape.k(), actualShape.n())); + copyGmToL1B(l1BTensor[pingpongFlag], gB, layoutBInL1, layoutTileB); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + AscendC::WaitFlag(pingpongFlag); + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(kRound, nRound); + copyL1ToL0B(l0BTensor[pingpongFlag], l1BTensor[pingpongFlag], layoutBInL0, layoutBInL1); + + Arch::CrossCoreWaitFlag(softmaxReady); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), actualShape.k())); + copyGmToL1A(l1ATensor[pingpongFlag], gA, layoutAInL1, layoutTileA); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(mRound, kRound); + copyL1ToL0A(l0ATensor[pingpongFlag], l1ATensor[pingpongFlag], layoutAInL0, layoutAInL1); + AscendC::SetFlag(pingpongFlag + PINGPONG_FLAG_OFFSET); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + AscendC::WaitFlag(pingpongFlag); + tileMmad(l0CTensor[pingpongFlag], l0ATensor[pingpongFlag], l0BTensor[pingpongFlag], mRound, nRound, + actualShape.k()); + AscendC::SetFlag(pingpongFlag); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + // copy block out + auto blockShape = MakeCoord(actualShape.m(), actualShape.n()); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(blockShape); + + copyL0CToGm(gC, l0CTensor[pingpongFlag], layoutC, layoutInL0C); + AscendC::SetFlag(pingpongFlag); + + pingpongFlag = 1 - pingpongFlag; + } + +protected: + /// Data members + AscendC::LocalTensor l1ATensor[STAGES]; + AscendC::LocalTensor l1BTensor[STAGES]; + AscendC::LocalTensor l0ATensor[STAGES]; + AscendC::LocalTensor l0BTensor[STAGES]; + AscendC::LocalTensor l0CTensor[STAGES]; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +//////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_FA_PV_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fa_qk.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fa_qk.hpp new file mode 100644 index 00000000..8c7c8fce --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fa_qk.hpp @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_FA_QK_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_FA_QK_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +//////////////////////////////////////////////////////////////////// + +namespace Catlass::Gemm::Block { +//////////////////////////////////////////////////////////////////// + +template +struct BlockMmad { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2FAQK; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + static constexpr uint32_t L0C_PINGPONG_BUF_SIZE = L0C_SIZE / STAGES; + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + /// Construct + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + for (uint32_t i = 0; i < STAGES; i++) { + l1ATensor[i] = resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE * i); + l1BTensor[i] = + resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE * STAGES + L1B_SIZE * i); + l0ATensor[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensor[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + l0CTensor[i] = resource.l0CBuf.template GetBufferByByte(L0C_PINGPONG_BUF_SIZE * i); + } + + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + 
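+        // Together with the two SetFlag calls above, these pre-arm the ping-pong event
+        // flags for both stages (EVENT_ID0 / EVENT_ID1) so that the first WaitFlag of
+        // each pair in operator() can proceed without blocking; the destructor issues
+        // the matching WaitFlag calls to drain them.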
AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() + { + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + } + + /// Perform a block-scoped matrix multiply-accumulate + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gA, AscendC::GlobalTensor gB, + AscendC::GlobalTensor gC, LayoutA layoutA, LayoutB layoutB, LayoutC layoutC, + GemmCoord actualShape, uint32_t &pingpongFlag, bool isFirst) + { + LayoutAInL1 layoutAInL1 = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + LayoutBInL1 layoutBInL1 = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + uint32_t kRound = RoundUp(actualShape.k()); + + if (isFirst) { + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), actualShape.k())); + copyGmToL1A(l1ATensor[0], gA, layoutAInL1, layoutTileA); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + + AscendC::WaitFlag(pingpongFlag); + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(mRound, kRound); + copyL1ToL0A(l0ATensor[pingpongFlag], l1ATensor[0], layoutAInL0, layoutAInL1); + + AscendC::WaitFlag(pingpongFlag); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(actualShape.k(), actualShape.n())); + copyGmToL1B(l1BTensor[pingpongFlag], gB, layoutBInL1, layoutTileB); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(kRound, nRound); + copyL1ToL0B(l0BTensor[pingpongFlag], l1BTensor[pingpongFlag], layoutBInL0, layoutBInL1); + AscendC::SetFlag(pingpongFlag); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + AscendC::WaitFlag(pingpongFlag); + tileMmad(l0CTensor[pingpongFlag], l0ATensor[pingpongFlag], l0BTensor[pingpongFlag], mRound, nRound, + actualShape.k()); + AscendC::SetFlag(pingpongFlag); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + // copy block out + auto blockShape = MakeCoord(actualShape.m(), actualShape.n()); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(blockShape); + + copyL0CToGm(gC, l0CTensor[pingpongFlag], layoutC, layoutInL0C); + AscendC::SetFlag(pingpongFlag); + + pingpongFlag = 1 - pingpongFlag; + } + +protected: + /// Data members + AscendC::LocalTensor l1ATensor[STAGES]; + AscendC::LocalTensor l1BTensor[STAGES]; + AscendC::LocalTensor l0ATensor[STAGES]; + AscendC::LocalTensor l0BTensor[STAGES]; + AscendC::LocalTensor l0CTensor[STAGES]; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +//////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_FA_QK_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fai_pv_head_tail.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fai_pv_head_tail.hpp new file mode 100644 index 00000000..0727110d --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fai_pv_head_tail.hpp @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. 
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_BLOCK_MMAD_PV_TAIL_HPP +#define CATLASS_GEMM_BLOCK_MMAD_PV_TAIL_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +//////////////////////////////////////////////////////////////////// + +namespace Catlass::Gemm::Block { +//////////////////////////////////////////////////////////////////// + +template +struct BlockMmad, L1TileShape_, L0TileShape_, AType_, BType_, + CType_, BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2FAITailPV; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + static constexpr uint32_t L0C_PINGPONG_BUF_SIZE = L0C_SIZE / STAGES; + static constexpr uint32_t BLOCK_SIZE = 16; + static constexpr uint32_t EMBED_SPLIT_SIZE = 128; + static constexpr uint32_t UNIT_BLOCK_STACK_NUM = 4; + static constexpr uint32_t KV_BASE_BLOCK = 512; + static constexpr uint32_t KV_SPLIT_SIZE = 128; + static constexpr uint32_t LOAB_BLOCK = 1; + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + /// Construct + CATLASS_DEVICE + 
BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + // Allocate L1 memory space + l0ATensor = resource.l0ABuf.template GetBufferByByte(0); + l0BTensor = resource.l0BBuf.template GetBufferByByte(0); + for (uint32_t i = 0; i < STAGES; i++) { + l1ATensor[i] = + resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE * LOAB_BLOCK * i); + l0CTensor[i] = resource.l0CBuf.template GetBufferByByte(L0C_PINGPONG_BUF_SIZE * i); + l1BTensor[i] = + resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE * 2 * 2 + L1B_SIZE * i); + } + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() {} + + CATLASS_DEVICE + void getBlockShape(GemmCoord &actualShape, uint32_t &nowNIdx, uint32_t &kIdx, uint32_t &nLoop, uint32_t &kLoop, + uint32_t &kvSeqlen, uint32_t &embed, bool firstBlock, uint32_t maskTailS = 0) + { + uint32_t nSplitSize = KV_SPLIT_SIZE * LOAB_BLOCK; + uint32_t embedSplitSize = EMBED_SPLIT_SIZE; + if (nowNIdx + LOAB_BLOCK > nLoop - 1) { + nSplitSize = kvSeqlen - nowNIdx * KV_SPLIT_SIZE; + } + if (firstBlock && maskTailS != 0) { + nSplitSize = nSplitSize - maskTailS; + } + if (kIdx == kLoop - 1) { + embedSplitSize = embed - kIdx * EMBED_SPLIT_SIZE; + } + actualShape[1] = embedSplitSize; + actualShape[2] = nSplitSize; + } + + CATLASS_DEVICE + void getKVOffset(AscendC::GlobalTensor &gBlockTable, uint32_t &kOffset, uint32_t &nowNIdx, uint32_t &kIdx, + uint32_t &nLoop, uint32_t &kLoop, uint32_t &strideKV, uint32_t &blockSize, uint32_t maskTailS = 0) + { + if (nowNIdx >= nLoop || kIdx >= kLoop) { + kOffset = 0; + } + if constexpr (PAGED_CACHE_FLAG_) { + uint32_t blockTableId = gBlockTable.GetValue(nowNIdx); + kOffset = blockTableId * blockSize * strideKV + maskTailS * strideKV + kIdx * EMBED_SPLIT_SIZE; + } else { + kOffset = nowNIdx * KV_SPLIT_SIZE * strideKV + kIdx * EMBED_SPLIT_SIZE; + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gA, AscendC::GlobalTensor gB, + AscendC::GlobalTensor gC, AscendC::GlobalTensor gBlockTable, LayoutA layoutA, + LayoutB layoutB, GemmCoord actualOriShape, uint32_t &nIdx, uint32_t &nLoop, uint32_t &blockSize, + uint32_t kvSeqlen, uint32_t strideKV, Arch::CrossCoreFlag softmaxFlag, uint32_t maskTailS, + bool preloadFlag) + { + uint32_t embed = actualOriShape[1]; + uint32_t kLoop = CeilDiv(embed); + uint32_t rowNum = layoutA.shape(0); + uint32_t blockN = layoutA.shape(1); + GemmCoord actualShape{rowNum, 0, 0}; + GemmCoord actualNextShape{rowNum, 0, 0}; + uint32_t nkBlockLoop = (nLoop + LOAB_BLOCK - 1) / LOAB_BLOCK * kLoop; // gap + uint32_t nkBlockNextIdx = (nIdx + LOAB_BLOCK - 1) / LOAB_BLOCK * kLoop + 1; // gap + uint32_t gBOffset = 0; + uint32_t gBNextOffset = 0; + uint32_t nowMaskTailS = 0; + uint32_t gPOffset = 0; + for (uint32_t kIdx = 0; kIdx < kLoop; kIdx++) { + nowMaskTailS = maskTailS; + gPOffset = 0; + for (uint32_t blockStackIdx = 0; (blockStackIdx < UNIT_BLOCK_STACK_NUM) && ((nIdx + blockStackIdx) < nLoop); + blockStackIdx += LOAB_BLOCK) { + uint32_t nowNIdx = nIdx + blockStackIdx; + uint32_t kLoopNextIdx = + (nkBlockNextIdx % (kLoop * UNIT_BLOCK_STACK_NUM)) / (UNIT_BLOCK_STACK_NUM / LOAB_BLOCK); + uint32_t nLoopNextIdx = + (nkBlockNextIdx % (kLoop * UNIT_BLOCK_STACK_NUM)) % (UNIT_BLOCK_STACK_NUM / LOAB_BLOCK) + + nkBlockNextIdx / (kLoop * UNIT_BLOCK_STACK_NUM) * UNIT_BLOCK_STACK_NUM; + uint32_t startSeqOffset = nowNIdx == nIdx ? maskTailS : 0; + uint32_t startSeqNxtOffset = nLoopNextIdx == nIdx ? 
maskTailS : 0; + getBlockShape(actualShape, nowNIdx, kIdx, nLoop, kLoop, kvSeqlen, embed, nowNIdx == nIdx, nowMaskTailS); + getBlockShape(actualNextShape, nLoopNextIdx, kLoopNextIdx, nLoop, kLoop, kvSeqlen, embed, + nLoopNextIdx == nIdx, nowMaskTailS); + getKVOffset(gBlockTable, gBOffset, nowNIdx, kIdx, nLoop, kLoop, strideKV, blockSize, startSeqOffset); + getKVOffset(gBlockTable, gBNextOffset, nLoopNextIdx, kLoopNextIdx, nLoop, kLoop, strideKV, blockSize, + startSeqNxtOffset); + bool firstItr = blockStackIdx == 0; + bool endItr = + (blockStackIdx + LOAB_BLOCK > UNIT_BLOCK_STACK_NUM - 1) || (nowNIdx + LOAB_BLOCK > nLoop - 1); + bool initMmad = blockStackIdx == 0; + bool pvCVItr = firstItr && kIdx == 0; + LayoutC layoutOTmpTemp(rowNum, embed, embed); + computePV(gA[gPOffset], gB[gBOffset], gC, gB[gBNextOffset], layoutA, layoutB, layoutOTmpTemp, + actualShape, actualNextShape, nowNIdx, nkBlockNextIdx, nkBlockLoop, firstItr, endItr, + initMmad, pvCVItr, softmaxFlag, preloadFlag); + gPOffset += actualShape.k(); + ++nkBlockNextIdx; + nowMaskTailS = 0; + preloadFlag = false; + } + } + } + + CATLASS_DEVICE + void computePV(AscendC::GlobalTensor const &gA, AscendC::GlobalTensor const &gB, + AscendC::GlobalTensor const &gC, AscendC::GlobalTensor const &gmNextBlockB, + LayoutA layoutA, LayoutB layoutB, LayoutC layoutC, GemmCoord actualShape, GemmCoord actualNextShape, + uint32_t nowIdx, uint32_t &nkblockIdx, uint32_t &nkblockLoop, bool firstItr, bool endItr, + bool initMmad, bool pvCVItr, Arch::CrossCoreFlag softmaxFlag, bool preloadFlag = false) + { + uint32_t MActual = actualShape.m(); + uint32_t kActual = actualShape.k(); + uint32_t nActual = actualShape.n(); + uint32_t mRound = RoundUp(MActual); + uint32_t kRound = RoundUp(kActual); + uint32_t nRound = RoundUp(nActual); + LayoutAInL1 layoutAInL1 = LayoutAInL1::template MakeLayout(mRound, kActual); + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(mRound, kActual); + LayoutBInL1 layoutBInL1 = LayoutBInL1::template MakeLayout(kActual, nActual); + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(kActual, nActual); + uint32_t l1KvPingPongFlag = nkblockIdx % 2; + uint32_t l0ABPingPongFlag = nkblockIdx % 2; + if (nkblockIdx == 1 || preloadFlag) { + auto layoutBTile = layoutB.GetTileLayout(MakeCoord(kActual, nActual)); + AscendC::WaitFlag(l1KvPingPongFlag + 2); + copyGmToL1B(l1BTensor[l1KvPingPongFlag], gB, layoutBInL1, layoutBTile); + AscendC::SetFlag(l1KvPingPongFlag + 2); + } + + AscendC::WaitFlag(l1KvPingPongFlag + 2); + AscendC::WaitFlag(2); + AscendC::WaitFlag(3); + copyL1ToL0B(l0BTensor, l1BTensor[l1KvPingPongFlag], layoutBInL0, layoutBInL1); + AscendC::SetFlag(l1KvPingPongFlag + 2); + + if (pvCVItr) { + Arch::CrossCoreWaitFlag(softmaxFlag); + } + AscendC::WaitFlag(l1KvPingPongFlag + 4); + auto layoutATile = layoutA.GetTileLayout(MakeCoord(MActual, kActual)); + copyGmToL1A(l1ATensor[l1KvPingPongFlag], gA, layoutAInL1, layoutATile); + AscendC::SetFlag(EVENT_ID5); + + if (nkblockIdx != nkblockLoop) { + uint32_t nNextActual = actualNextShape.n(); + uint32_t kNextActual = actualNextShape.k(); + LayoutBInL1 layoutBNextInL1 = LayoutBInL1::template MakeLayout(kNextActual, nNextActual); + auto layoutNextBTile = layoutB.GetTileLayout(MakeCoord(kNextActual, nNextActual)); + AscendC::WaitFlag(1 - l1KvPingPongFlag + 2); + copyGmToL1B(l1BTensor[1 - l1KvPingPongFlag], gmNextBlockB, layoutBNextInL1, layoutNextBTile); + AscendC::SetFlag(1 - l1KvPingPongFlag + 2); + } + + AscendC::WaitFlag(EVENT_ID5); + AscendC::WaitFlag(0); + 
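+        // Together with the wait just below, this drains both events that
+        // guard l0ATensor (presumably the Cube-side hand-back flags), so the
+        // P tile can be restaged into L0A without racing the previous mmad.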
AscendC::WaitFlag(1); + copyL1ToL0A(l0ATensor, l1ATensor[l1KvPingPongFlag], layoutAInL0, layoutAInL1); + AscendC::SetFlag(l1KvPingPongFlag + 4); + + AscendC::SetFlag(l0ABPingPongFlag); + AscendC::WaitFlag(l0ABPingPongFlag); + uint8_t unitFlag = 0b00; + if constexpr (!ENABLE_UNIT_FLAG_) { + if (firstItr) { + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + } + } else { + if (endItr) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + tileMmad(l0CTensor[0], l0ATensor, l0BTensor, mRound, nActual, kActual, initMmad, unitFlag); + // AscendC::PipeBarrier(); + AscendC::SetFlag(0); + AscendC::SetFlag(1); + AscendC::SetFlag(2); + AscendC::SetFlag(3); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(MakeCoord(MActual, nActual)); + if (endItr) { + if constexpr (!ENABLE_UNIT_FLAG_) { + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + copyL0CToGm(gC, l0CTensor[0], layoutC, layoutInL0C); + AscendC::SetFlag(0); + AscendC::SetFlag(1); + } else { + copyL0CToGm(gC, l0CTensor[0], layoutC, layoutInL0C, 0b11); + } + } + } + +protected: + /// Data members + AscendC::LocalTensor l1ATensor[STAGES]; + AscendC::LocalTensor l1BTensor[STAGES]; + AscendC::LocalTensor l0ATensor; + AscendC::LocalTensor l0BTensor; + AscendC::LocalTensor l0CTensor[STAGES]; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +//////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_MMAD_PV_TAIL_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fai_pv_normal.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fai_pv_normal.hpp new file mode 100644 index 00000000..72bd33bf --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fai_pv_normal.hpp @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_BLOCK_MMAD_PV_HPP +#define CATLASS_GEMM_BLOCK_MMAD_PV_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +//////////////////////////////////////////////////////////////////// + +namespace Catlass::Gemm::Block { +//////////////////////////////////////////////////////////////////// + +template +struct BlockMmad, L1TileShape_, L0TileShape_, AType_, BType_, + CType_, BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2FAIPV; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + static constexpr uint32_t L0C_PINGPONG_BUF_SIZE = L0C_SIZE / STAGES; + static constexpr uint32_t BLOCK_SIZE = 16; + static constexpr uint32_t EMBED_SPLIT_SIZE = 128; + static constexpr uint32_t UNIT_BLOCK_STACK_NUM = 4; + static constexpr uint32_t KV_BASE_BLOCK = 512; + static constexpr uint32_t KV_SPLIT_SIZE = 128; + static constexpr uint32_t LOAB_BLOCK = 1; + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + /// Construct + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + // Allocate L1 memory space + l1ATensor = resource.l1Buf.template GetBufferByByte(l1BufAddrStart); + for (uint32_t i = 0; i < STAGES; i++) { + l0ATensor[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensor[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + l0CTensor[i] = resource.l0CBuf.template GetBufferByByte(L0C_PINGPONG_BUF_SIZE * i); + l1BTensor[i] = + resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE * 2 * 2 + L1B_SIZE * i); 
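+            // Same L1 carving as the head/tail PV variant: B tiles sit above
+            // the region reserved for A; the L1A_SIZE * 2 * 2 offset appears
+            // to reserve room for all UNIT_BLOCK_STACK_NUM stacked P tiles
+            // that precede B in L1.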
+ } + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() {} + + CATLASS_DEVICE + void getBlockShape(GemmCoord &actualShape, uint32_t &nowNIdx, uint32_t &kIdx, uint32_t &nLoop, uint32_t &kLoop, + uint32_t &kvSeqlen, uint32_t &embed, bool firstBlock, uint32_t maskTailS = 0) + { + uint32_t nSplitSize = KV_SPLIT_SIZE * LOAB_BLOCK; + uint32_t embedSplitSize = EMBED_SPLIT_SIZE; + if (nowNIdx + LOAB_BLOCK > nLoop - 1) { + nSplitSize = kvSeqlen - nowNIdx * KV_SPLIT_SIZE; + } + if (firstBlock && maskTailS != 0) { + nSplitSize = nSplitSize - maskTailS; + } + if (kIdx == kLoop - 1) { + embedSplitSize = embed - kIdx * EMBED_SPLIT_SIZE; + } + actualShape[1] = embedSplitSize; + actualShape[2] = nSplitSize; + } + + CATLASS_DEVICE + void getKVOffset(AscendC::GlobalTensor &gBlockTable, uint32_t &kOffset, uint32_t &nowNIdx, uint32_t &kIdx, + uint32_t &nLoop, uint32_t &kLoop, uint32_t &strideKV, uint32_t &blockSize, uint32_t maskTailS = 0) + { + if (nowNIdx >= nLoop || kIdx >= kLoop) { + kOffset = 0; + } + if constexpr (PAGED_CACHE_FLAG_) { + uint32_t blockTableId = gBlockTable.GetValue(nowNIdx); + kOffset = blockTableId * blockSize * strideKV + maskTailS * strideKV + kIdx * EMBED_SPLIT_SIZE; + } else { + kOffset = nowNIdx * KV_SPLIT_SIZE * strideKV + kIdx * EMBED_SPLIT_SIZE; + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gA, AscendC::GlobalTensor gB, + AscendC::GlobalTensor gC, AscendC::GlobalTensor gBlockTable, LayoutA layoutA, + LayoutB layoutB, GemmCoord actualOriShape, uint32_t &nIdx, uint32_t &nLoop, uint32_t &blockSize, + uint32_t kvSeqlen, uint32_t strideKV, Arch::CrossCoreFlag softmaxFlag) + { + // Arch::CrossCoreWaitFlag(softmaxFlag); + uint32_t embed = actualOriShape[1]; + uint32_t kLoop = CeilDiv(embed); + uint32_t rowNum = layoutA.shape(0); + uint32_t blockN = layoutA.shape(1); + GemmCoord actualShape{rowNum, 0, 0}; + GemmCoord actualNextShape{rowNum, 0, 0}; + uint32_t nkBlockLoop = (nLoop + LOAB_BLOCK - 1) / LOAB_BLOCK * kLoop; // gap + uint32_t nkBlockNextIdx = (nIdx + LOAB_BLOCK - 1) / LOAB_BLOCK * kLoop + 1; // gap + uint32_t gBOffset = 0; + uint32_t gBNextOffset = 0; + for (uint32_t kIdx = 0; kIdx < kLoop; kIdx++) { + for (uint32_t blockStackIdx = 0; (blockStackIdx < UNIT_BLOCK_STACK_NUM) && ((nIdx + blockStackIdx) < nLoop); + blockStackIdx += LOAB_BLOCK) { + uint32_t nowNIdx = nIdx + blockStackIdx; + uint32_t kLoopNextIdx = + (nkBlockNextIdx % (kLoop * UNIT_BLOCK_STACK_NUM)) / (UNIT_BLOCK_STACK_NUM / LOAB_BLOCK); + uint32_t nLoopNextIdx = + (nkBlockNextIdx % (kLoop * UNIT_BLOCK_STACK_NUM)) % (UNIT_BLOCK_STACK_NUM / LOAB_BLOCK) + + nkBlockNextIdx / (kLoop * UNIT_BLOCK_STACK_NUM) * UNIT_BLOCK_STACK_NUM; + getBlockShape(actualShape, nowNIdx, kIdx, nLoop, kLoop, kvSeqlen, embed, nowNIdx == nIdx); + getBlockShape(actualNextShape, nLoopNextIdx, kLoopNextIdx, nLoop, kLoop, kvSeqlen, embed, + nowNIdx == nIdx); + getKVOffset(gBlockTable, gBOffset, nowNIdx, kIdx, nLoop, kLoop, strideKV, blockSize); + getKVOffset(gBlockTable, gBNextOffset, nLoopNextIdx, kLoopNextIdx, nLoop, kLoop, strideKV, blockSize); + bool firstItr = blockStackIdx == 0; + bool endItr = + (blockStackIdx + LOAB_BLOCK > UNIT_BLOCK_STACK_NUM - 1) || (nowNIdx + LOAB_BLOCK > nLoop - 1); + bool initMmad = blockStackIdx == 0; + bool pvCVItr = firstItr && kIdx == 0; + LayoutC layoutOTmpTemp(rowNum, embed, embed); + computePV(gA, gB[gBOffset], gC, gB[gBNextOffset], layoutA, layoutB, layoutOTmpTemp, actualShape, + actualNextShape, blockStackIdx, nkBlockNextIdx, nkBlockLoop, firstItr, endItr, initMmad, + pvCVItr, 
softmaxFlag); + ++nkBlockNextIdx; + } + } + } + + CATLASS_DEVICE + void computePV(AscendC::GlobalTensor const &gA, AscendC::GlobalTensor const &gB, + AscendC::GlobalTensor const &gC, AscendC::GlobalTensor const &gmNextBlockB, + LayoutA layoutA, LayoutB layoutB, LayoutC layoutC, GemmCoord actualShape, GemmCoord actualNextShape, + uint32_t nowIdx, uint32_t &nkblockIdx, uint32_t &nkblockLoop, bool firstItr, bool endItr, + bool initMmad, bool pvCVItr, Arch::CrossCoreFlag softmaxFlag, bool preloadFlag = false) + { + uint32_t MActual = actualShape.m(); + uint32_t kActual = actualShape.k(); + uint32_t nActual = actualShape.n(); + uint32_t mRound = RoundUp(MActual); + uint32_t kRound = RoundUp(kActual); + uint32_t nRound = RoundUp(nActual); + LayoutAInL1 layoutAInL1 = LayoutAInL1::template MakeLayout(mRound, (uint32_t)512); + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(mRound, kActual); + LayoutBInL1 layoutBInL1 = LayoutBInL1::template MakeLayout(kActual, nActual); + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(kActual, nActual); + uint32_t l1KvPingPongFlag = nkblockIdx % 2; + uint32_t l0ABPingPongFlag = nkblockIdx % 2; + if (nkblockIdx == 1 || preloadFlag) { + auto layoutBTile = layoutB.GetTileLayout(MakeCoord(kActual, nActual)); + AscendC::WaitFlag(l1KvPingPongFlag + 2); + copyGmToL1B(l1BTensor[l1KvPingPongFlag], gB, layoutBInL1, layoutBTile); + AscendC::SetFlag(l1KvPingPongFlag + 2); + } + + AscendC::WaitFlag(l1KvPingPongFlag + 2); + AscendC::WaitFlag(l0ABPingPongFlag + 2); + copyL1ToL0B(l0BTensor[l0ABPingPongFlag], l1BTensor[l1KvPingPongFlag], layoutBInL0, layoutBInL1); + AscendC::SetFlag(l1KvPingPongFlag + 2); + + if (pvCVItr) { + Arch::CrossCoreWaitFlag(softmaxFlag); + AscendC::WaitFlag(4); + AscendC::WaitFlag(5); + auto layoutATile = layoutA.GetTileLayout(MakeCoord(MActual, (uint32_t)512)); + copyGmToL1A(l1ATensor, gA, layoutAInL1, layoutATile); + AscendC::SetFlag(4); + AscendC::WaitFlag(4); + } + + if (nkblockIdx != nkblockLoop) { + uint32_t nNextActual = actualNextShape.n(); + uint32_t kNextActual = actualNextShape.k(); + LayoutBInL1 layoutBNextInL1 = LayoutBInL1::template MakeLayout(kNextActual, nNextActual); + auto layoutNextBTile = layoutB.GetTileLayout(MakeCoord(kNextActual, nNextActual)); + AscendC::WaitFlag(1 - l1KvPingPongFlag + 2); + copyGmToL1B(l1BTensor[1 - l1KvPingPongFlag], gmNextBlockB, layoutBNextInL1, layoutNextBTile); + AscendC::SetFlag(1 - l1KvPingPongFlag + 2); + } + + AscendC::WaitFlag(l0ABPingPongFlag); + copyL1ToL0A(l0ATensor[l0ABPingPongFlag], l1ATensor[mRound * 128 * nowIdx], layoutAInL0, layoutAInL1); + + AscendC::SetFlag(l0ABPingPongFlag); + AscendC::WaitFlag(l0ABPingPongFlag); + uint8_t unitFlag = 0b00; + if constexpr (!ENABLE_UNIT_FLAG_) { + if (firstItr) { + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + } + } else { + if (endItr) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + tileMmad(l0CTensor[0], l0ATensor[l0ABPingPongFlag], l0BTensor[l0ABPingPongFlag], mRound, nActual, kActual, + initMmad, unitFlag); + // AscendC::PipeBarrier(); + AscendC::SetFlag(l0ABPingPongFlag); + AscendC::SetFlag(l0ABPingPongFlag + 2); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(MakeCoord(MActual, nActual)); + if (endItr) { + AscendC::SetFlag(4); + AscendC::SetFlag(5); + if constexpr (!ENABLE_UNIT_FLAG_) { + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + copyL0CToGm(gC, l0CTensor[0], layoutC, layoutInL0C); + AscendC::SetFlag(0); + AscendC::SetFlag(1); + } else { + copyL0CToGm(gC, l0CTensor[0], layoutC, layoutInL0C, 0b11); + } + } + } + 
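+    // Rough data flow of one operator() call (sketch only; the stage names
+    // below are descriptive, not identifiers from this file):
+    //
+    //   for each kIdx (embed split) x stacked KV block:
+    //     computePV():
+    //       MTE2: prefetch the next V tile into the spare L1 ping-pong slot
+    //       MTE1: stage P (l1ATensor) and V (l1BTensor) into L0A / L0B
+    //       Cube: tileMmad accumulates into l0CTensor[0]
+    //       Fixp: on the last stacked block (endItr), copyL0CToGm writes O out
+    //
+    //   softmaxFlag is consumed only on the very first P*V iteration
+    //   (pvCVItr), i.e. the cube core blocks until the vector core has
+    //   published the softmax'ed P tile, which is also when P is loaded
+    //   into L1.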
+protected: + /// Data members + AscendC::LocalTensor l1ATensor; + AscendC::LocalTensor l1BTensor[STAGES]; + AscendC::LocalTensor l0ATensor[STAGES]; + AscendC::LocalTensor l0BTensor[STAGES]; + AscendC::LocalTensor l0CTensor[STAGES]; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +//////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_MMAD_PV_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fai_qk_head_tail.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fai_qk_head_tail.hpp new file mode 100644 index 00000000..2eb1184f --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fai_qk_head_tail.hpp @@ -0,0 +1,286 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_BLOCK_MMAD_QK_TAIL_HPP +#define CATLASS_GEMM_BLOCK_MMAD_QK_TAIL_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +//////////////////////////////////////////////////////////////////// + +namespace Catlass::Gemm::Block { +//////////////////////////////////////////////////////////////////// + +template +struct BlockMmad, L1TileShape_, L0TileShape_, AType_, BType_, + CType_, BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2FAITailQK; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t STAGES = 
DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + static constexpr uint32_t L0C_PINGPONG_BUF_SIZE = L0C_SIZE / STAGES; + static constexpr uint32_t BLOCK_SIZE = 16; + static constexpr uint32_t EMBED_SPLIT_SIZE = 128; + static constexpr uint32_t UNIT_BLOCK_STACK_NUM = 4; + static constexpr uint32_t KV_BASE_BLOCK = 512; + static constexpr uint32_t KV_SPLIT_SIZE = 128; + + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + // Allocate L1 memory space + l1ATensor = resource.l1Buf.template GetBufferByByte(l1BufAddrStart); + for (uint32_t i = 0; i < STAGES; i++) { + l1BTensor[i] = resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE + L1B_SIZE * i); + l0ATensor[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensor[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + } + l0CTensor = resource.l0CBuf.template GetBufferByByte(0); + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() {} + + CATLASS_DEVICE + void loadQGM(AscendC::GlobalTensor gA, LayoutA layoutA, uint32_t rowNum, uint32_t &singleGroupHeads, + uint32_t &qHeads) + { + uint32_t embed = layoutA.shape(1); + uint32_t rowNumRound = RoundUp(rowNum); + uint32_t tokenNumPerGroup = rowNum / singleGroupHeads; + auto layoutSingleANd = layoutA.GetTileLayout(MakeCoord(singleGroupHeads, embed)); + LayoutAInL1 layoutAInL1 = LayoutAInL1::template MakeLayout(rowNum, embed); + copyGmToL1A(l1ATensor, gA, layoutAInL1, layoutSingleANd, tokenNumPerGroup, qHeads * embed, tokenNumPerGroup, + BLOCK_SIZE, rowNumRound); + AscendC::SetFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID3); + } + + CATLASS_DEVICE + void getBlockShape(GemmCoord &actualShape, uint32_t &nowNIdx, uint32_t &kIdx, uint32_t &nLoop, uint32_t &kLoop, + uint32_t &kvSeqlen, uint32_t &embed, bool firstBlock, uint32_t maskTailS = 0) + { + uint32_t nSplitSize = KV_SPLIT_SIZE; + uint32_t embedSplitSize = EMBED_SPLIT_SIZE; + if (nowNIdx == nLoop - 1) { + nSplitSize = kvSeqlen - nowNIdx * KV_SPLIT_SIZE; + } + if (firstBlock && maskTailS != 0) { + nSplitSize = nSplitSize - maskTailS; + } + if (kIdx == kLoop - 1) { + embedSplitSize = embed - kIdx * EMBED_SPLIT_SIZE; + } + actualShape[1] = nSplitSize; + actualShape[2] = embedSplitSize; + } + + CATLASS_DEVICE + void getKVOffset(AscendC::GlobalTensor &gBlockTable, uint32_t &kOffset, uint32_t &nowNIdx, uint32_t &kIdx, + uint32_t &nLoop, uint32_t &kLoop, uint32_t &strideKV, uint32_t &blockSize, uint32_t maskTailS = 0) + { + if (nowNIdx >= nLoop || kIdx >= kLoop) { + kOffset = 0; + } + if constexpr (PAGED_CACHE_FLAG_) { + uint32_t blockTableId = gBlockTable.GetValue(nowNIdx); + kOffset = blockTableId * blockSize * strideKV + maskTailS * strideKV + kIdx * EMBED_SPLIT_SIZE; + } else { + kOffset = nowNIdx * KV_SPLIT_SIZE * strideKV + kIdx * EMBED_SPLIT_SIZE; + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gA, AscendC::GlobalTensor gB, + AscendC::GlobalTensor gC, AscendC::GlobalTensor gBlockTable, LayoutA 
layoutA, + LayoutB layoutB, GemmCoord actualOriShape, uint32_t &nIdx, uint32_t &nLoop, uint32_t &blockSize, + uint32_t kvSeqlen, uint32_t strideKV, uint32_t maskTailS, bool preloadFlag) + { + uint32_t rowNum = actualOriShape[0]; + uint32_t embed = actualOriShape[2]; + uint32_t kLoop = CeilDiv(embed); + uint32_t nkBlockLoop = nLoop * kLoop; + GemmCoord actualShape{rowNum, 0, 0}; + GemmCoord actualNextShape{rowNum, 0, 0}; + uint32_t nkBlockNextIdx = nIdx * kLoop + 1; + uint32_t gBOffset = 0; + uint32_t gBNextOffset = 0; + uint32_t stackTile = 0; + uint32_t preStackTile = 0; + for (uint32_t blockStackIdx = 0; (blockStackIdx < UNIT_BLOCK_STACK_NUM) && ((nIdx + blockStackIdx) < nLoop); + ++blockStackIdx) { + for (uint32_t kIdx = 0; kIdx < kLoop; kIdx++) { + uint32_t nowNIdx = nIdx + blockStackIdx; + uint32_t nLoopNextIdx = nkBlockNextIdx / kLoop; + uint32_t kLoopNextIdx = nkBlockNextIdx % kLoop; + uint32_t gCOffset = stackTile; + uint32_t startSeqOffset = nowNIdx == nIdx ? maskTailS : 0; + uint32_t startSeqNxtOffset = nLoopNextIdx == nIdx ? maskTailS : 0; + getBlockShape(actualShape, nowNIdx, kIdx, nLoop, kLoop, kvSeqlen, embed, nowNIdx == nIdx, maskTailS); + getBlockShape(actualNextShape, nLoopNextIdx, kLoopNextIdx, nLoop, kLoop, kvSeqlen, embed, + nLoopNextIdx == nIdx, maskTailS); + getKVOffset(gBlockTable, gBOffset, nowNIdx, kIdx, nLoop, kLoop, strideKV, blockSize, startSeqOffset); + getKVOffset(gBlockTable, gBNextOffset, nLoopNextIdx, kLoopNextIdx, nLoop, kLoop, strideKV, blockSize, + startSeqNxtOffset); + bool firstItr = kIdx == 0; + bool endItr = kIdx == kLoop - 1; + bool firstQtr = blockStackIdx == 0; + bool endQItr = + ((nowNIdx == nLoop - 1) || (blockStackIdx == UNIT_BLOCK_STACK_NUM - 1)) && (kIdx == kLoop - 1); + bool initMmad = kIdx == 0; + preStackTile = stackTile; + stackTile += actualShape[1]; + LayoutC layOutSTemp(rowNum, stackTile, 512); + computeQK(gA, gB[gBOffset], gC[gCOffset], gB[gBNextOffset], layoutA, layoutB, layOutSTemp, actualShape, + actualNextShape, blockStackIdx, nkBlockNextIdx, nkBlockLoop, firstItr, endItr, initMmad, + firstQtr, endQItr, preloadFlag); + ++nkBlockNextIdx; + preloadFlag = false; + } + maskTailS = 0; + } + } + + CATLASS_DEVICE void computeQK(AscendC::GlobalTensor const &gA, AscendC::GlobalTensor const &gB, + AscendC::GlobalTensor const &gC, + AscendC::GlobalTensor const &gmNextBlockB, LayoutA layoutA, LayoutB layoutB, + LayoutC layoutC, GemmCoord actualShape, GemmCoord actualNextShape, uint32_t nowIdx, + uint32_t &nkblockIdx, uint32_t &nkblockLoop, bool firstItr, bool endItr, + bool initMmad, bool firstQItr, bool endQItr, bool preloadFlag) + { + uint32_t mActual = actualShape.m(); + uint32_t kActual = actualShape.k(); + uint32_t nActual = actualShape.n(); + uint32_t mRound = RoundUp(mActual); + uint32_t kRound = RoundUp(kActual); + uint32_t nRound = RoundUp(nActual); + LayoutAInL1 layoutAInL1 = LayoutAInL1::template MakeLayout(mRound, kActual); // embed + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(mRound, kActual); + LayoutBInL1 layoutBInL1 = LayoutBInL1::template MakeLayout(kActual, nActual); + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(kActual, nRound); + uint32_t locPingPongFlag = nkblockIdx % 2; + uint32_t l1KvPingPongFlag = nkblockIdx % 2; + uint32_t l0ABPingPongFlag = nkblockIdx % 2; + if (preloadFlag) { + auto layoutBTile = layoutB.GetTileLayout(MakeCoord(kActual, nActual)); + AscendC::WaitFlag(l1KvPingPongFlag); + copyGmToL1B(l1BTensor[l1KvPingPongFlag], gB, layoutBInL1, layoutBTile); + 
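+            // preloadFlag marks the first tile of this pass: no earlier
+            // iteration has prefetched this K block, so it is loaded here
+            // before the hand-off flag below is raised.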
AscendC::SetFlag(l1KvPingPongFlag); + } + if (nkblockIdx != nkblockLoop) { + uint32_t nNextActual = actualNextShape.n(); + uint32_t kNextActual = actualNextShape.k(); + LayoutBInL1 layoutBNextInL1 = LayoutBInL1::template MakeLayout(kNextActual, nNextActual); + auto layoutNextBTile = layoutB.GetTileLayout(MakeCoord(kNextActual, nNextActual)); + AscendC::WaitFlag(1 - l1KvPingPongFlag); + copyGmToL1B(l1BTensor[1 - l1KvPingPongFlag], gmNextBlockB, layoutBNextInL1, layoutNextBTile); + AscendC::SetFlag(1 - l1KvPingPongFlag); + } + AscendC::WaitFlag(l0ABPingPongFlag); + copyL1ToL0A(l0ATensor[l0ABPingPongFlag], l1ATensor, layoutAInL0, layoutAInL1); + + AscendC::WaitFlag(l1KvPingPongFlag); + AscendC::WaitFlag(l0ABPingPongFlag + 2); + copyL1ToL0B(l0BTensor[l0ABPingPongFlag], l1BTensor[l1KvPingPongFlag], layoutBInL0, layoutBInL1); + AscendC::SetFlag(l1KvPingPongFlag); + + AscendC::SetFlag(l0ABPingPongFlag); + AscendC::WaitFlag(l0ABPingPongFlag); + uint8_t unitFlag = 0b00; + if constexpr (!ENABLE_UNIT_FLAG_) { + if (firstItr) { + AscendC::WaitFlag(locPingPongFlag); + } + } else { + unitFlag = 0b11; + } + tileMmad(l0CTensor[locPingPongFlag * mRound * 128], l0ATensor[l0ABPingPongFlag], l0BTensor[l0ABPingPongFlag], + mRound, nActual, kActual, initMmad, unitFlag); + // AscendC::PipeBarrier(); + AscendC::SetFlag(l0ABPingPongFlag); + AscendC::SetFlag(l0ABPingPongFlag + 2); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(MakeCoord(mActual, nActual)); + if (endItr) { + if constexpr (!ENABLE_UNIT_FLAG_) { + AscendC::SetFlag(locPingPongFlag); + AscendC::WaitFlag(locPingPongFlag); + copyL0CToGm(gC, l0CTensor[locPingPongFlag * mRound * 128], layoutC, layoutInL0C); + AscendC::SetFlag(locPingPongFlag); + } else { + copyL0CToGm(gC, l0CTensor[locPingPongFlag * mRound * 128], layoutC, layoutInL0C, unitFlag); + } + } + } + +protected: + /// Data members + AscendC::LocalTensor l1ATensor; + AscendC::LocalTensor l1BTensor[STAGES]; + AscendC::LocalTensor l0ATensor[STAGES]; + AscendC::LocalTensor l0BTensor[STAGES]; + AscendC::LocalTensor l0CTensor; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +//////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_MMAD_QK_TAIL_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fai_qk_normal.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fai_qk_normal.hpp new file mode 100644 index 00000000..335fbbfb --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_fai_qk_normal.hpp @@ -0,0 +1,318 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_BLOCK_MMAD_QK_HPP +#define CATLASS_GEMM_BLOCK_MMAD_QK_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +//////////////////////////////////////////////////////////////////// + +namespace Catlass::Gemm::Block { +//////////////////////////////////////////////////////////////////// + +template +struct BlockMmad, L1TileShape_, L0TileShape_, AType_, BType_, + CType_, BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2FAIQK; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + static constexpr uint32_t L0C_PINGPONG_BUF_SIZE = L0C_SIZE / STAGES; + static constexpr uint32_t BLOCK_SIZE = 16; + static constexpr uint32_t EMBED_SPLIT_SIZE = 128; + static constexpr uint32_t UNIT_BLOCK_STACK_NUM = 4; + static constexpr uint32_t KV_BASE_BLOCK = 512; + static constexpr uint32_t KV_SPLIT_SIZE = 128; + + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + // Allocate L1 memory space + l1ATensor = resource.l1Buf.template GetBufferByByte(l1BufAddrStart); + for (uint32_t i = 0; i < STAGES; i++) { + l1BTensor[i] = resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE + L1B_SIZE * i); + l0ATensor[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensor[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + } + l0CTensor = resource.l0CBuf.template GetBufferByByte(0); + } + + CATLASS_DEVICE + ~BlockMmad() {} + + CATLASS_DEVICE + void loadQGM(AscendC::GlobalTensor gA, LayoutA 
layoutA, uint32_t rowNum, uint32_t &singleGroupHeads, + uint32_t &qHeads) + { + uint32_t embed = layoutA.shape(1); + uint32_t rowNumRound = RoundUp(rowNum); + uint32_t tokenNumPerGroup = rowNum / singleGroupHeads; + auto layoutSingleANd = layoutA.GetTileLayout(MakeCoord(singleGroupHeads, embed)); + LayoutAInL1 layoutAInL1 = LayoutAInL1::template MakeLayout(rowNum, embed); + copyGmToL1A(l1ATensor, gA, layoutAInL1, layoutSingleANd, tokenNumPerGroup, qHeads * embed, tokenNumPerGroup, + BLOCK_SIZE, rowNumRound); + // AscendC::Nd2NzParams intriParams; + // intriParams.nValue = singleGroupHeads; + // intriParams.dValue = embed; + // intriParams.srcDValue = embed; + // intriParams.dstNzNStride = tokenNumPerGroup; + // intriParams.dstNzC0Stride = rowNumRound; + // intriParams.ndNum = tokenNumPerGroup; + // intriParams.srcNdMatrixStride = qHeads * embed; + // intriParams.dstNzMatrixStride = 16; + // AscendC::DataCopy(l1ATensor, gA, intriParams); + AscendC::SetFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID3); + } + + CATLASS_DEVICE + void getBlockShape(GemmCoord &actualShape, uint32_t &nowNIdx, uint32_t &kIdx, uint32_t &nLoop, uint32_t &kLoop, + uint32_t &kvSeqlen, uint32_t &embed, bool firstBlock, uint32_t maskTailS = 0) + { + uint32_t nSplitSize = KV_SPLIT_SIZE; + uint32_t embedSplitSize = EMBED_SPLIT_SIZE; + if (nowNIdx == nLoop - 1) { + nSplitSize = kvSeqlen - nowNIdx * KV_SPLIT_SIZE; + } + if (firstBlock && maskTailS != 0) { + nSplitSize = nSplitSize - maskTailS; + } + // } + if (kIdx == kLoop - 1) { + embedSplitSize = embed - kIdx * EMBED_SPLIT_SIZE; + } + actualShape[1] = nSplitSize; + actualShape[2] = embedSplitSize; + } + + CATLASS_DEVICE + void getKVOffset(AscendC::GlobalTensor &gBlockTable, uint32_t &kOffset, uint32_t &nowNIdx, uint32_t &kIdx, + uint32_t &nLoop, uint32_t &kLoop, uint32_t &strideKV, uint32_t &blockSize, uint32_t maskTailS = 0) + { + if (nowNIdx >= nLoop || kIdx >= kLoop) { + kOffset = 0; + } + if constexpr (PAGED_CACHE_FLAG_) { + uint32_t blockTableId = gBlockTable.GetValue(nowNIdx); + kOffset = blockTableId * blockSize * strideKV + maskTailS * strideKV + kIdx * EMBED_SPLIT_SIZE; + } else { + kOffset = nowNIdx * KV_SPLIT_SIZE * strideKV + kIdx * EMBED_SPLIT_SIZE; + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gA, AscendC::GlobalTensor gB, + AscendC::GlobalTensor gC, AscendC::GlobalTensor gBlockTable, LayoutA layoutA, + LayoutB layoutB, GemmCoord actualOriShape, uint32_t &nIdx, uint32_t &nLoop, uint32_t &blockSize, + uint32_t kvSeqlen, uint32_t strideKV) + { + uint32_t rowNum = actualOriShape[0]; + uint32_t embed = actualOriShape[2]; + uint32_t kLoop = CeilDiv(embed); + uint32_t nkBlockLoop = nLoop * kLoop; + GemmCoord actualShape{rowNum, 0, 0}; + GemmCoord actualNextShape{rowNum, 0, 0}; + uint32_t nkBlockNextIdx = nIdx * kLoop + 1; + uint32_t gBOffset = 0; + uint32_t gBNextOffset = 0; + uint32_t stackTile = 0; + for (uint32_t blockStackIdx = 0; (blockStackIdx < UNIT_BLOCK_STACK_NUM) && ((nIdx + blockStackIdx) < nLoop); + ++blockStackIdx) { + for (uint32_t kIdx = 0; kIdx < kLoop; kIdx++) { + uint32_t nowNIdx = nIdx + blockStackIdx; + uint32_t nLoopNextIdx = nkBlockNextIdx / kLoop; + uint32_t kLoopNextIdx = nkBlockNextIdx % kLoop; + uint32_t gCOffset = blockStackIdx / 2 * 2 * KV_SPLIT_SIZE; + getBlockShape(actualShape, nowNIdx, kIdx, nLoop, kLoop, kvSeqlen, embed, nowNIdx == nIdx); + getBlockShape(actualNextShape, nLoopNextIdx, kLoopNextIdx, nLoop, kLoop, kvSeqlen, embed, + nLoopNextIdx == nIdx); + getKVOffset(gBlockTable, gBOffset, nowNIdx, kIdx, nLoop, 
kLoop, strideKV, blockSize); + getKVOffset(gBlockTable, gBNextOffset, nLoopNextIdx, kLoopNextIdx, nLoop, kLoop, strideKV, blockSize); + bool firstItr = ((blockStackIdx % 2) == 0) && (kIdx == 0); + bool endItr = (((blockStackIdx % 2) == 1) || (nowNIdx == nLoop - 1)) && (kIdx == kLoop - 1); + bool firstQtr = blockStackIdx == 0; + bool endQItr = + ((nowNIdx == nLoop - 1) || (blockStackIdx == UNIT_BLOCK_STACK_NUM - 1)) && (kIdx == kLoop - 1); + int cc = 1; + bool initMmad = kIdx == 0; + stackTile += actualShape[1]; + LayoutC layOutSTemp(rowNum, stackTile, 512); + // LayoutC layOutSTemp(rowNum, stackTile, 512); + // AscendC::printf("firstItr:%d\n", firstItr); + // AscendC::printf("endItr:%d\n", endItr); + // AscendC::printf("initMmad:%d\n", initMmad); + // AscendC::printf("stackTile:%d\n", stackTile); + // AscendC::printf("blockStackIdx:%d\n", blockStackIdx); + // AscendC::printf("gBOffset:%d\n", gBOffset); + // AscendC::printf("gBNextOffset:%d\n", gBNextOffset); + // AscendC::printf("actualShape[0]:%d\n", actualShape[0]); + // AscendC::printf("actualShape[1]:%d\n", actualShape[1]); + // AscendC::printf("actualShape[2]:%d\n", actualShape[2]); + // AscendC::printf("actualShape.m():%d\n", actualShape.m()); + // AscendC::printf("actualShape.n():%d\n", actualShape.n()); + // AscendC::printf("actualShape.k():%d\n", actualShape.k()); + // AscendC::printf("actualNextShape[0]:%d\n", actualNextShape[0]); + // AscendC::printf("actualNextShape[1]:%d\n", actualNextShape[1]); + // AscendC::printf("actualNextShape[2]:%d\n", actualNextShape[2]); + computeQK(gA, gB[gBOffset], gC[gCOffset], gB[gBNextOffset], layoutA, layoutB, layOutSTemp, actualShape, + actualNextShape, blockStackIdx, nkBlockNextIdx, nkBlockLoop, firstItr, endItr, initMmad, + firstQtr, endQItr); + ++nkBlockNextIdx; + if (endItr) { + stackTile = 0; + } + } + } + } + + CATLASS_DEVICE void computeQK(AscendC::GlobalTensor const &gA, AscendC::GlobalTensor const &gB, + AscendC::GlobalTensor const &gC, + AscendC::GlobalTensor const &gmNextBlockB, LayoutA layoutA, LayoutB layoutB, + LayoutC layoutC, GemmCoord actualShape, GemmCoord actualNextShape, uint32_t nowIdx, + uint32_t &nkblockIdx, uint32_t &nkblockLoop, bool firstItr, bool endItr, + bool initMmad, bool firstQItr, bool endQItr) + { + uint32_t mActual = actualShape.m(); + uint32_t kActual = actualShape.k(); + uint32_t nActual = actualShape.n(); + uint32_t mRound = RoundUp(mActual); + uint32_t kRound = RoundUp(kActual); + uint32_t nRound = RoundUp(nActual); + LayoutAInL1 layoutAInL1 = LayoutAInL1::template MakeLayout(mRound, kActual); // embed + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(mRound, kActual); + LayoutBInL1 layoutBInL1 = LayoutBInL1::template MakeLayout(kActual, nActual); + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(kActual, nRound); + uint32_t locPingPongFlag = nowIdx % 2; + uint32_t l1KvPingPongFlag = nkblockIdx % 2; + uint32_t l0ABPingPongFlag = nkblockIdx % 2; + if (nkblockIdx == 1) { + auto layoutBTile = layoutB.GetTileLayout(MakeCoord(kActual, nActual)); + AscendC::WaitFlag(l1KvPingPongFlag); + copyGmToL1B(l1BTensor[l1KvPingPongFlag], gB, layoutBInL1, layoutBTile); + AscendC::SetFlag(l1KvPingPongFlag); + } + if (nkblockIdx != nkblockLoop) { + uint32_t nNextActual = actualNextShape.n(); + uint32_t kNextActual = actualNextShape.k(); + LayoutBInL1 layoutBNextInL1 = LayoutBInL1::template MakeLayout(kNextActual, nNextActual); + auto layoutNextBTile = layoutB.GetTileLayout(MakeCoord(kNextActual, nNextActual)); + AscendC::WaitFlag(1 - l1KvPingPongFlag); + 
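+            // Prefetch the K tile of the next (n, k) block into the other
+            // half of the L1 ping-pong buffer while the current one is
+            // consumed below.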
copyGmToL1B(l1BTensor[1 - l1KvPingPongFlag], gmNextBlockB, layoutBNextInL1, layoutNextBTile); + AscendC::SetFlag(1 - l1KvPingPongFlag); + } + if (firstQItr) { + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + copyL1ToL0A(l0ATensor[0], l1ATensor, layoutAInL0, layoutAInL1); + } + + AscendC::WaitFlag(l1KvPingPongFlag); + AscendC::WaitFlag(l0ABPingPongFlag + 2); + copyL1ToL0B(l0BTensor[l0ABPingPongFlag], l1BTensor[l1KvPingPongFlag], layoutBInL0, layoutBInL1); + AscendC::SetFlag(l1KvPingPongFlag); + + AscendC::SetFlag(l0ABPingPongFlag); + AscendC::WaitFlag(l0ABPingPongFlag); + uint8_t unitFlag = 0b00; + if constexpr (!ENABLE_UNIT_FLAG_) { + if (firstItr) { + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + } + } else { + unitFlag = 0b11; + } + tileMmad(l0CTensor[locPingPongFlag * mRound * 128], l0ATensor[0], l0BTensor[l0ABPingPongFlag], mRound, nActual, + kActual, initMmad, unitFlag); + // AscendC::PipeBarrier(); + if (endQItr) { + AscendC::SetFlag(0); + AscendC::SetFlag(1); + } + AscendC::SetFlag(l0ABPingPongFlag + 2); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(MakeCoord(mActual, (uint32_t)256)); + if (endItr) { + if constexpr (!ENABLE_UNIT_FLAG_) { + AscendC::SetFlag(l1KvPingPongFlag); + AscendC::WaitFlag(l1KvPingPongFlag); + copyL0CToGm(gC, l0CTensor, layoutC, layoutInL0C); + AscendC::SetFlag(0); + AscendC::SetFlag(1); + } else { + copyL0CToGm(gC, l0CTensor, layoutC, layoutInL0C, unitFlag); + } + } + } + +protected: + /// Data members + AscendC::LocalTensor l1ATensor; + AscendC::LocalTensor l1BTensor[STAGES]; + AscendC::LocalTensor l0ATensor[STAGES]; + AscendC::LocalTensor l0BTensor[STAGES]; + AscendC::LocalTensor l0CTensor; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +//////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_MMAD_QK_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_gemm.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_gemm.hpp new file mode 100644 index 00000000..b0a66b25 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_gemm.hpp @@ -0,0 +1,327 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_BLOCK_BLOCK_MMAD_GEMM_HPP +#define CATLASS_BLOCK_BLOCK_MMAD_GEMM_HPP + +#include "catlass/catlass.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/arch/resource.hpp" + +namespace Catlass::Gemm::Block { + +template +struct BlockGemm, L1TileShape_, L0TileShape_, + AType_, BType_, CType_, BiasType_, TileCopy_, TileMmad_> { +public: + using DispatchPolicy = Gemm::GemmAtlasA2; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyGmToL1A::LayoutDst; + using LayoutBInL1 = typename CopyGmToL1B::LayoutDst; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + static constexpr bool ENABLE_ABBA = DispatchPolicy::ENABLE_ABBA; + const uint32_t L1Size = ArchTag::L1_SIZE; + const uint32_t L1ASize = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + const uint32_t L1BSize = L1TileShape::K * L1TileShape::N * sizeof(ElementB); + const uint32_t cSize = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator); + const uint32_t BlockCnt = L1TileShape::M * L1TileShape::N; + const uint32_t L0ASize = ArchTag::L0A_SIZE; + const uint32_t L0BSize = ArchTag::L0B_SIZE; + const uint32_t L0CSize = ArchTag::L0C_SIZE; + const uint32_t L0A_PINGPONG_BUF_LEN = (L0ASize / STAGES); + const uint32_t L0B_PINGPONG_BUF_LEN = (L0BSize / STAGES); + const uint32_t l0CBlockNum = ArchTag::L0C_SIZE / cSize; + + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + CATLASS_DEVICE + BlockGemm(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1ASize * STAGES; + for (uint32_t i = 0; i < STAGES; i++) { + l1ATensor[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1ASize * i); + l1BTensor[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1BSize * i); + l0ATensor[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_LEN * i); + l0BTensor[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_LEN * i); + // Assign event ID for each stages + l1AEventList[i] = i; + l1BEventList[i] = i + STAGES; + l0AEventList[i] = i; + l0BEventList[i] = i + STAGES; + // The event id that needs to be set before the loop + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + AscendC::SetFlag(l0AEventList[i]); + 
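+            // Priming all four flags per stage means every WaitFlag issued
+            // inside operator() has a matching SetFlag, so the first copies
+            // of the main loop cannot dead-lock.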
AscendC::SetFlag(l0BEventList[i]); + } + l0CTensor = resource.l0CBuf.template GetBufferByByte(0); + } + // destroy function + CATLASS_DEVICE + ~BlockGemm() + { + for (uint32_t i = 0; i < STAGES; i++) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + AscendC::WaitFlag(l0AEventList[i]); + AscendC::WaitFlag(l0BEventList[i]); + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmC, LayoutC const &layoutC, + AscendC::GlobalTensor const &gmNextBlockA, + AscendC::GlobalTensor const &gmNextBlockB, GemmCoord const &actualShape, + GemmCoord const &actualShapeNext, bool isFirstBlock, bool hasNextBlock, uint32_t singleIdx) + { + uint32_t K = actualShape.k(); + uint32_t maxKPerBlock = L1TileShape::K; + uint32_t kLoops = CeilDiv(K, maxKPerBlock); + uint32_t kLoopsNext = CeilDiv(actualShapeNext.k(), maxKPerBlock); + uint32_t startTileIdx{0}; + if (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx(); + } + uint32_t firstTileIdx = startTileIdx % kLoops; + uint32_t firstTileIdxNext = startTileIdx % kLoopsNext; + uint32_t lastTileIdx = (startTileIdx + kLoops - 1) % kLoops; + uint32_t kGmActual = (firstTileIdx == kLoops - 1) ? (K - firstTileIdx * maxKPerBlock) : maxKPerBlock; + auto layoutAInL1 = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + auto layoutBInL1 = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + for (uint32_t kIdx = 0; kIdx < kLoops; kIdx++) { + uint32_t shuffleKIdx = (startTileIdx + kIdx) % kLoops; + if (shuffleKIdx == firstTileIdx && isFirstBlock) { + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kGmActual)); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kGmActual, actualShape.n())); + MatrixCoord gmTileAOffset{0, shuffleKIdx * maxKPerBlock}; + auto gmTileA = gmA[layoutA.GetOffset(gmTileAOffset)]; + MatrixCoord gmTileBOffset{shuffleKIdx * maxKPerBlock, 0}; + auto gmTileB = gmB[layoutB.GetOffset(gmTileBOffset)]; + AscendC::WaitFlag(l1AEventList[l1ListId]); + copyGmToL1A(l1ATensor[l1ListId], gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + AscendC::WaitFlag(l1BEventList[l1ListId]); + copyGmToL1B(l1BTensor[l1ListId], gmTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + } + l1ListIdNext = 1 - l1ListId; + uint32_t kGmActualNext = 0; + if (shuffleKIdx != lastTileIdx) { + uint32_t shuffleKIdxNext = (startTileIdx + kIdx + 1) % kLoops; + kGmActualNext = (shuffleKIdxNext == kLoops - 1) ? 
(K - shuffleKIdxNext * maxKPerBlock) : maxKPerBlock; + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kGmActualNext)); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kGmActualNext, actualShape.n())); + MatrixCoord gmTileAOffset{0, shuffleKIdxNext * maxKPerBlock}; + auto gmTileA = gmA[layoutA.GetOffset(gmTileAOffset)]; + MatrixCoord gmTileBOffset{shuffleKIdxNext * maxKPerBlock, 0}; + auto gmTileB = gmB[layoutB.GetOffset(gmTileBOffset)]; + if (ENABLE_ABBA) { + if (shuffleKIdxNext % 2 == 1) { + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + copyGmToL1B(l1BTensor[l1ListIdNext], gmTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + copyGmToL1A(l1ATensor[l1ListIdNext], gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + } else { + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + copyGmToL1A(l1ATensor[l1ListIdNext], gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + copyGmToL1B(l1BTensor[l1ListIdNext], gmTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + } else { + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + copyGmToL1A(l1ATensor[l1ListIdNext], gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + copyGmToL1B(l1BTensor[l1ListIdNext], gmTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + } + if (shuffleKIdx == lastTileIdx && hasNextBlock) { + kGmActualNext = (firstTileIdxNext == kLoopsNext - 1) + ? (actualShapeNext.k() - firstTileIdxNext * maxKPerBlock) + : maxKPerBlock; + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShapeNext.m(), kGmActualNext)); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kGmActualNext, actualShapeNext.n())); + MatrixCoord gmTileAOffset{0, firstTileIdxNext * maxKPerBlock}; + auto gmNextTileA = gmNextBlockA[layoutA.GetOffset(gmTileAOffset)]; + MatrixCoord gmTileBOffset{firstTileIdxNext * maxKPerBlock, 0}; + auto gmNextTileB = gmNextBlockB[layoutB.GetOffset(gmTileBOffset)]; + if (ENABLE_ABBA) { + if (shuffleKIdx % 2 == 0) { + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + copyGmToL1B(l1BTensor[l1ListIdNext], gmNextTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + copyGmToL1A(l1ATensor[l1ListIdNext], gmNextTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + } else { + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + copyGmToL1A(l1ATensor[l1ListIdNext], gmNextTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + copyGmToL1B(l1BTensor[l1ListIdNext], gmNextTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + } else { + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + copyGmToL1A(l1ATensor[l1ListIdNext], gmNextTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + copyGmToL1B(l1BTensor[l1ListIdNext], gmNextTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + } + + uint32_t kL0TileSize = L0TileShape::K; + uint32_t kL0Loops = CeilDiv(kGmActual, kL0TileSize); + AscendC::WaitFlag(l1AEventList[l1ListId]); + AscendC::WaitFlag(l1BEventList[l1ListId]); + 
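+            // Both L1 tiles of this k-slice are now resident; below they are
+            // cut into L0TileShape::K sub-tiles and fed to the Cube pipe,
+            // with the A/B copy order alternated when ENABLE_ABBA is set so
+            // consecutive iterations interleave the two MTE1 streams.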
auto l1ATile = l1ATensor[l1ListId]; + auto l1BTile = l1BTensor[l1ListId]; + uint32_t mActual{0}; + uint32_t nActual{0}; + for (uint32_t kL0Idx = 0; kL0Idx < kL0Loops; kL0Idx++) { + uint32_t kL0Actual = (kL0Idx == kL0Loops - 1) ? (kGmActual - kL0Idx * kL0TileSize) : kL0TileSize; + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(L1TileShape::M, kL0Actual); + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(kL0Actual, L1TileShape::N); + uint32_t l1TileAOffset = layoutAInL1.GetOffset(MatrixCoord(0, kL0Idx * kL0TileSize)); + uint32_t l1TileBOffset = layoutBInL1.GetOffset(MatrixCoord(kL0Idx * kL0TileSize, 0)); + auto l1TileA = l1ATile[l1TileAOffset]; + auto l1TileB = l1BTile[l1TileBOffset]; + auto l0TileA = l0ATensor[l0ListId]; + auto l0TileB = l0BTensor[l0ListId]; + mActual = L1TileShape::M; + nActual = L1TileShape::N; + if (ENABLE_ABBA) { + if (shuffleKIdx % 2 == 0) { + if (kL0Idx % 2 == 0) { + AscendC::WaitFlag(l0BEventList[l0ListId]); + copyL1ToL0B(l0TileB, l1TileB, layoutBInL0, layoutBInL1); + AscendC::SetFlag(l0BEventList[l0ListId]); + AscendC::WaitFlag(l0AEventList[l0ListId]); + copyL1ToL0A(l0TileA, l1TileA, layoutAInL0, layoutAInL1); + AscendC::SetFlag(l0AEventList[l0ListId]); + } else { + AscendC::WaitFlag(l0AEventList[l0ListId]); + copyL1ToL0A(l0TileA, l1TileA, layoutAInL0, layoutAInL1); + AscendC::SetFlag(l0AEventList[l0ListId]); + AscendC::WaitFlag(l0BEventList[l0ListId]); + copyL1ToL0B(l0TileB, l1TileB, layoutBInL0, layoutBInL1); + AscendC::SetFlag(l0BEventList[l0ListId]); + } + } else { + if (kL0Idx % 2 == 0) { + AscendC::WaitFlag(l0AEventList[l0ListId]); + copyL1ToL0A(l0TileA, l1TileA, layoutAInL0, layoutAInL1); + AscendC::SetFlag(l0AEventList[l0ListId]); + AscendC::WaitFlag(l0BEventList[l0ListId]); + copyL1ToL0B(l0TileB, l1TileB, layoutBInL0, layoutBInL1); + AscendC::SetFlag(l0BEventList[l0ListId]); + } else { + AscendC::WaitFlag(l0BEventList[l0ListId]); + copyL1ToL0B(l0TileB, l1TileB, layoutBInL0, layoutBInL1); + AscendC::SetFlag(l0BEventList[l0ListId]); + AscendC::WaitFlag(l0AEventList[l0ListId]); + copyL1ToL0A(l0TileA, l1TileA, layoutAInL0, layoutAInL1); + AscendC::SetFlag(l0AEventList[l0ListId]); + } + } + } else { + AscendC::WaitFlag(l0AEventList[l0ListId]); + copyL1ToL0A(l0TileA, l1TileA, layoutAInL0, layoutAInL1); + AscendC::SetFlag(l0AEventList[l0ListId]); + AscendC::WaitFlag(l0BEventList[l0ListId]); + copyL1ToL0B(l0TileB, l1TileB, layoutBInL0, layoutBInL1); + AscendC::SetFlag(l0BEventList[l0ListId]); + } + if (kL0Idx == kL0Loops - 1) { + AscendC::SetFlag(l1AEventList[l1ListId]); + AscendC::SetFlag(l1BEventList[l1ListId]); + l1ListId = l1ListIdNext; + kGmActual = kGmActualNext; + } + AscendC::WaitFlag(l0BEventList[l0ListId]); + AscendC::WaitFlag(l0AEventList[l0ListId]); + tileMmad(l0CTensor[(singleIdx % l0CBlockNum) * BlockCnt], l0TileA, l0TileB, mActual, nActual, kL0Actual, + (kIdx == 0) && (kL0Idx == 0)); + AscendC::SetFlag(l0AEventList[l0ListId]); + AscendC::SetFlag(l0BEventList[l0ListId]); + l0ListId = 1 - l0ListId; + } + } + AscendC::SetFlag((int32_t)(singleIdx % l0CBlockNum)); + AscendC::WaitFlag((int32_t)(singleIdx % l0CBlockNum)); + auto layoutInL0X = LayoutCInL0::MakeLayoutInL0C(MakeCoord(L1TileShape::M, L1TileShape::N)); + LayoutC layoutBlock = layoutC.GetTileLayout(MakeCoord(actualShape.m(), actualShape.n())); + copyL0CToGm(gmC, l0CTensor[(singleIdx % l0CBlockNum) * BlockCnt], layoutBlock, layoutInL0X); + } + +private: + AscendC::LocalTensor l1ATensor[STAGES]; + AscendC::LocalTensor l1BTensor[STAGES]; + AscendC::LocalTensor 
l0ATensor[STAGES]; + AscendC::LocalTensor l0BTensor[STAGES]; + AscendC::LocalTensor l0CTensor; + // Multi-stage event id list + int32_t l1AEventList[STAGES]; + int32_t l1BEventList[STAGES]; + int32_t l0AEventList[STAGES]; + int32_t l0BEventList[STAGES]; + + // The id of current stage + uint32_t l1ListId{0}; + uint32_t l0ListId{0}; + uint32_t l1ListIdNext{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_BLOCK_BLOCK_MMAD_GEMM_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_mla_pv.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_mla_pv.hpp new file mode 100644 index 00000000..cbd0ab23 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_mla_pv.hpp @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_MLA_PV_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_MLA_PV_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +//////////////////////////////////////////////////////////////////// + +namespace Catlass::Gemm::Block { +//////////////////////////////////////////////////////////////////// + +template +struct BlockMmad { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2MLAPV; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static 
constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + static constexpr uint32_t L0C_PINGPONG_BUF_SIZE = L0C_SIZE / STAGES; + static constexpr uint32_t BLOCK_SIZE = 16; + static constexpr uint32_t EMBED_SPLIT_SIZE = 128; + static constexpr uint32_t EMBED_SPLIT_LOOP = 4; + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + /// Construct + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + // Allocate L1 memory space + l1ATensor = resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE + L1B_SIZE * STAGES); + for (uint32_t i = 0; i < STAGES; i++) { + l1BTensor[i] = resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE + L1B_SIZE * i); + l0ATensor[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensor[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + l0CTensor[i] = resource.l0CBuf.template GetBufferByByte(L0C_PINGPONG_BUF_SIZE * i); + } + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() {} + + /// Perform a block-scoped matrix multiply-accumulate + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gA, AscendC::GlobalTensor gC, LayoutA layoutA, + LayoutB layoutB, LayoutC layoutC, GemmCoord actualShape, uint32_t &nIdx, + Arch::CrossCoreFlag softmaxReady) + { + uint32_t rowNum = actualShape.m(); + uint32_t vSeqTile = actualShape.k(); + uint32_t embed = actualShape.n(); + uint32_t embedSplitSize = EMBED_SPLIT_SIZE; + uint32_t embedSplitLoopV = EMBED_SPLIT_LOOP; + uint32_t rowNumRound = RoundUp(rowNum); + uint32_t embedSplitSizeRound = RoundUp(embedSplitSize); + uint32_t vSeqTileRound = RoundUp(vSeqTile); + uint32_t L1BPingPongFlag = (nIdx - 1) % 2; + uint32_t L0APingPongFlag = (nIdx - 1) % 2; + + for (uint32_t embedSplitIdx = 0; embedSplitIdx < embedSplitLoopV; embedSplitIdx++) { + uint32_t L0CPingPongFlag = (nIdx + embedSplitIdx) % 2; + uint32_t L0BPingPongFlag = (embedSplitIdx + 1) % 2; + AscendC::WaitFlag(L0BPingPongFlag + 2); + LayoutBInL1 layoutBInL1 = LayoutBInL1::template MakeLayout(vSeqTile, embed); + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(vSeqTile, embedSplitSize); + // copy V from L1 to L0B + copyL1ToL0B(l0BTensor[L0BPingPongFlag], + l1BTensor[L1BPingPongFlag][embedSplitIdx * vSeqTileRound * EMBED_SPLIT_SIZE], layoutBInL0, + layoutBInL1); + if (embedSplitIdx == embedSplitLoopV - 1) { + AscendC::SetFlag(L1BPingPongFlag); + } + + if (embedSplitIdx == 0) { + LayoutAInL1 layoutAInL1 = LayoutAInL1::template MakeLayout(rowNum, vSeqTile); + Arch::CrossCoreWaitFlag(softmaxReady); + AscendC::WaitFlag(EVENT_ID7); + // copy P to L1 + copyGmToL1A(l1ATensor, gA, layoutAInL1, layoutA); + AscendC::SetFlag(EVENT_ID7); + AscendC::WaitFlag(EVENT_ID7); + // move p from l1 to l0a + AscendC::WaitFlag(L0APingPongFlag); + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(rowNum, vSeqTile); + copyL1ToL0A(l0ATensor[L0APingPongFlag], l1ATensor, layoutAInL0, layoutAInL1); + AscendC::SetFlag(EVENT_ID7); + } + + AscendC::SetFlag(L0BPingPongFlag); + AscendC::WaitFlag(L0BPingPongFlag); + 
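// Illustrative sketch (plain standalone C++, not AscendC): the ping-pong indexing behind
// the L1B/L0A/L0C pong flags used here. With two stages the buffer index is just the
// iteration parity, so iteration i can compute on one buffer while the next iteration's
// copy targets the other. The printed trace and names are specific to this sketch.
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr uint32_t stages = 2;
    for (uint32_t i = 0; i < 6; ++i) {
        const uint32_t buf = i % stages;  // buffer owned by iteration i
        std::printf("iteration %u uses buffer %u; the next copy can target buffer %u\n",
                    i, buf, 1 - buf);
    }
    return 0;
}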
AscendC::WaitFlag(L0CPingPongFlag); + // mmad + tileMmad(l0CTensor[L0CPingPongFlag], l0ATensor[L0APingPongFlag], l0BTensor[L0BPingPongFlag], rowNumRound, + embedSplitSize, vSeqTile); + AscendC::SetFlag(L0BPingPongFlag + 2); + if (embedSplitIdx == embedSplitLoopV - 1) { + AscendC::SetFlag(L0APingPongFlag); + } + AscendC::SetFlag(L0CPingPongFlag); + + AscendC::WaitFlag(L0CPingPongFlag); + auto blockShape = MakeCoord(rowNum, embedSplitSize); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(blockShape); + auto layoutCSplitK = layoutC.GetTileLayout(MakeCoord(rowNumRound, embedSplitSizeRound)); + // copy Otmp to gm + copyL0CToGm(gC[embedSplitIdx * embedSplitSizeRound], l0CTensor[L0CPingPongFlag], layoutCSplitK, + layoutInL0C); + AscendC::SetFlag(L0CPingPongFlag); + } + } + +protected: + /// Data members + AscendC::LocalTensor l1ATensor; + AscendC::LocalTensor l1BTensor[STAGES]; + AscendC::LocalTensor l0ATensor[STAGES]; + AscendC::LocalTensor l0BTensor[STAGES]; + AscendC::LocalTensor l0CTensor[STAGES]; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +//////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_MLA_PV_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_mla_pv_tp1_spec.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_mla_pv_tp1_spec.hpp new file mode 100644 index 00000000..8cce0a58 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_mla_pv_tp1_spec.hpp @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_MLA_PV_TP1_SPEC_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_MLA_PV_TP1_SPEC_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +//////////////////////////////////////////////////////////////////// + +namespace Catlass::Gemm::Block { +//////////////////////////////////////////////////////////////////// + +template +struct BlockMmad { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2MLAPV; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + static constexpr uint32_t L0C_PINGPONG_BUF_SIZE = L0C_SIZE / STAGES; + static constexpr uint32_t UNIT_BLOCK_STACK_NUM = 4; + static constexpr uint32_t EMBED_SPLIT_LOOP = 4; + static constexpr uint32_t BLOCK_SIZE = 16; + static constexpr uint32_t EMBED_SPLIT_SIZE = 128; + static constexpr uint32_t L1_PV_ADDR_START = 311296; // reserved for Q(no db) and K(db) + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * EMBED_SPLIT_SIZE * sizeof(ElementB); + + /// Construct + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = L1_PV_ADDR_START) + { + // Allocate L1 memory space + for (uint32_t i = 0; i < STAGES; i++) { + l1ATensor[i] = + resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1B_SIZE * STAGES + L1A_SIZE * i); + l1BTensor[i] = resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1B_SIZE * i); + l0ATensor[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensor[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + l0CTensor[i] = resource.l0CBuf.template GetBufferByByte(L0C_PINGPONG_BUF_SIZE * i); + } + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() {} + 
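// Illustrative sketch (plain standalone C++, not AscendC): the byte-offset arithmetic the
// constructor above uses to carve one flat L1 budget into double-buffered regions, with
// the B buffers packed first and the A buffers placed after both of them. The tile sizes
// below are made-up example values, not this kernel's real configuration.
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr uint32_t stages = 2;
    constexpr uint32_t l1BSize = 128 * 128 * sizeof(uint16_t);  // example B tile, fp16
    constexpr uint32_t l1ASize = 128 * 512 * sizeof(uint16_t);  // example A tile, fp16
    constexpr uint32_t base = 0;                                // example start offset

    for (uint32_t i = 0; i < stages; ++i) {
        const uint32_t bOffset = base + l1BSize * i;                     // B buffer i
        const uint32_t aOffset = base + l1BSize * stages + l1ASize * i;  // A buffer i
        std::printf("stage %u: B at byte %u, A at byte %u\n", i, bOffset, aOffset);
    }
    return 0;
}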
+ /// Perform a block-scoped matrix multiply-accumulate + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gA, AscendC::GlobalTensor gB, + AscendC::GlobalTensor gblockTable, AscendC::GlobalTensor gC, LayoutA layoutA, + LayoutB layoutB, LayoutC layoutC, GemmCoord actualShape, uint32_t &nIdx, uint32_t &nLoop, + uint32_t &blockSize, uint32_t kvSeqlen, Arch::CrossCoreFlag softmaxReady) + { + uint32_t rowNum = actualShape.m(); + uint32_t stackSeqTile = actualShape.k(); + uint32_t seqTile = blockSize; + uint32_t embed = actualShape.n(); + uint32_t embedSplitSize = EMBED_SPLIT_SIZE; + uint32_t embedSplitLoopV = EMBED_SPLIT_LOOP; + uint32_t rowNumRound = RoundUp(rowNum); + uint32_t embedSplitSizeRound = RoundUp(embedSplitSize); + uint32_t stackSeqTileRound = layoutA.stride(0); + uint32_t seqTileRound = RoundUp(seqTile); + + for (uint32_t embedSplitIdx = 0; embedSplitIdx < embedSplitLoopV; embedSplitIdx++) { + uint32_t L0CPingPongFlag = embedSplitIdx % 2; + for (uint32_t blockStackIdx = 0; + (blockStackIdx < UNIT_BLOCK_STACK_NUM) && ((nIdx + blockStackIdx) < (nLoop + UNIT_BLOCK_STACK_NUM)); + blockStackIdx++) { + uint32_t nIdxActual = nIdx + blockStackIdx; + if (nIdxActual == (nLoop + UNIT_BLOCK_STACK_NUM - 1)) { + seqTile = (kvSeqlen - (nIdxActual - UNIT_BLOCK_STACK_NUM) * blockSize); + } else { + seqTile = blockSize; + } + seqTileRound = RoundUp(seqTile); + uint32_t L1ABPingPongFlag = nIdxActual % 2; + uint32_t L0ABPingPongFlag = nIdxActual % 2; + uint32_t blockTableId = gblockTable.GetValue(nIdxActual - UNIT_BLOCK_STACK_NUM); + uint64_t kvOffset = (uint64_t)blockTableId * blockSize * embed + embedSplitIdx * embedSplitSize; + + AscendC::WaitFlag(L1ABPingPongFlag + 4); + auto layoutUnitBSplitN = layoutB.GetTileLayout(MakeCoord(seqTile, embedSplitSize)); + // copy V to L1 + LayoutBInL1 layoutUnitBSplitNInL1 = LayoutBInL1::template MakeLayout(seqTile, embedSplitSize); + copyGmToL1B(l1BTensor[L1ABPingPongFlag], gB[kvOffset], layoutUnitBSplitNInL1, layoutUnitBSplitN); + AscendC::SetFlag(L1ABPingPongFlag); + + AscendC::WaitFlag(L1ABPingPongFlag); + AscendC::WaitFlag(L0ABPingPongFlag + 2); + LayoutBInL0 layoutUnitBSplitNInL0 = LayoutBInL0::template MakeLayout(seqTile, embedSplitSize); + // copy V from L1 to L0B + copyL1ToL0B(l0BTensor[L0ABPingPongFlag], l1BTensor[L1ABPingPongFlag], layoutUnitBSplitNInL0, + layoutUnitBSplitNInL1); + AscendC::SetFlag(L1ABPingPongFlag + 4); + + if (embedSplitIdx == 0 && blockStackIdx == 0) { + Arch::CrossCoreWaitFlag(softmaxReady); + } + AscendC::WaitFlag(L1ABPingPongFlag + 6); + auto layoutASplitK = layoutA.GetTileLayout(MakeCoord(rowNum, seqTile)); + LayoutAInL1 layoutASplitKInL1 = LayoutAInL1::template MakeLayout(rowNum, seqTile); + // copy P to L1 + copyGmToL1A(l1ATensor[L1ABPingPongFlag], gA[blockStackIdx * blockSize], layoutASplitKInL1, + layoutASplitK); + AscendC::SetFlag(EVENT_ID7); + // copy P to l0a + AscendC::WaitFlag(EVENT_ID7); + AscendC::WaitFlag(L0ABPingPongFlag); + LayoutAInL0 layoutASplitKInL0 = LayoutAInL0::template MakeLayout(rowNum, seqTile); + copyL1ToL0A(l0ATensor[L0ABPingPongFlag], l1ATensor[L1ABPingPongFlag], layoutASplitKInL0, + layoutASplitKInL1); + AscendC::SetFlag(L1ABPingPongFlag + 6); + AscendC::SetFlag(L0ABPingPongFlag); + // mmad + AscendC::WaitFlag(L0ABPingPongFlag); + if (blockStackIdx == 0) { + AscendC::WaitFlag(L0CPingPongFlag); + } + tileMmad(l0CTensor[L0CPingPongFlag], l0ATensor[L0ABPingPongFlag], l0BTensor[L0ABPingPongFlag], + rowNumRound, embedSplitSize, seqTile, blockStackIdx == 0); + AscendC::SetFlag(L0ABPingPongFlag + 2); + 
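// Illustrative sketch (plain standalone C++, not AscendC): the paged-KV addressing used
// above. A block table maps a logical KV block index to a physical block id, and the
// element offset into the KV pool is physicalId * blockSize * embed, plus an optional
// column offset when only a slice of the embedding is loaded. Names are sketch-only.
#include <cstdint>
#include <vector>

static inline uint64_t KvElementOffset(const std::vector<int32_t> &blockTable, uint32_t logicalBlockIdx,
                                       uint32_t blockSize, uint32_t embed, uint32_t embedColOffset = 0)
{
    const uint64_t physicalId = static_cast<uint64_t>(blockTable[logicalBlockIdx]);
    return physicalId * blockSize * embed + embedColOffset;
}
// Example: blockTable = {7, 2, 5}, blockSize = 128, embed = 512:
// logical block 1 starts at element offset 2 * 128 * 512 of the pool.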
AscendC::SetFlag(L0ABPingPongFlag); + } + // copy Otmp to gm + AscendC::SetFlag(L0CPingPongFlag); + AscendC::WaitFlag(L0CPingPongFlag); + auto blockShape = MakeCoord(rowNum, embedSplitSize); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(blockShape); + auto layoutCSplitN = layoutC.GetTileLayout(MakeCoord(rowNumRound, embedSplitSizeRound)); + copyL0CToGm(gC[embedSplitIdx * embedSplitSizeRound], l0CTensor[L0CPingPongFlag], layoutCSplitN, + layoutInL0C); + AscendC::SetFlag(L0CPingPongFlag); + } + } + +protected: + /// Data members + AscendC::LocalTensor l1ATensor[STAGES]; + AscendC::LocalTensor l1BTensor[STAGES]; + AscendC::LocalTensor l0ATensor[STAGES]; + AscendC::LocalTensor l0BTensor[STAGES]; + AscendC::LocalTensor l0CTensor[STAGES]; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +//////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_MLA_PV_TP1_SPEC_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_mla_qk.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_mla_qk.hpp new file mode 100644 index 00000000..ad4a3068 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_mla_qk.hpp @@ -0,0 +1,215 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_MLA_QK_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_MLA_QK_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +//////////////////////////////////////////////////////////////////// + +namespace Catlass::Gemm::Block { +//////////////////////////////////////////////////////////////////// + +template +struct BlockMmad { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2MLAQK; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + static constexpr uint32_t L0C_PINGPONG_BUF_SIZE = L0C_SIZE / STAGES; + static constexpr uint32_t BLOCK_SIZE = 16; + static constexpr uint32_t EMBED_SPLIT_SIZE = 128; + static constexpr uint32_t EMBED_SPLIT_LOOP = 5; + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + /// Construct + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + // Allocate L1 memory space + l1ATensor = resource.l1Buf.template GetBufferByByte(l1BufAddrStart); + for (uint32_t i = 0; i < STAGES; i++) { + l1BTensor[i] = resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE + L1B_SIZE * i); + l0ATensor[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensor[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + l0CTensor[i] = resource.l0CBuf.template GetBufferByByte(L0C_PINGPONG_BUF_SIZE * i); + } + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() {} + + /// Perform a block-scoped matrix multiply-accumulate + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gA, AscendC::GlobalTensor gARope, + 
AscendC::GlobalTensor gB, AscendC::GlobalTensor gBRope, + AscendC::GlobalTensor gC, LayoutA layoutA, LayoutA layoutARope, LayoutB layoutB, + LayoutB layoutBRope, LayoutC layoutC, GemmCoord actualShape, MatrixCoord qShapeSingleNd, + uint32_t &qHeads, uint32_t &nIdx) + { + uint32_t rowNum = actualShape.m(); + uint32_t kSeqTile = actualShape.n(); + uint32_t embed = layoutB.shape(0); + uint32_t embedRope = layoutBRope.shape(0); + uint32_t embedCat = actualShape.k(); + uint32_t embedSplitSize = EMBED_SPLIT_SIZE; + uint32_t embedSplitLoopK = EMBED_SPLIT_LOOP; + uint32_t curHeadNum = qShapeSingleNd.row(); + uint32_t tokenNumPerHead = rowNum / curHeadNum; + uint32_t kSeqTileRound = RoundUp(kSeqTile); + uint32_t rowNumRound = RoundUp(rowNum); + uint32_t l1KvPingPongFlag = nIdx % 2; + + if (nIdx == 0) { + // copy Q to L1 + auto layoutASingleNd = layoutA.GetTileLayout(MakeCoord(curHeadNum, embed)); + LayoutAInL1 layoutAInL1 = LayoutAInL1::template MakeLayout(rowNum, embed); + copyGmToL1A(l1ATensor, gA, layoutAInL1, layoutASingleNd, tokenNumPerHead, qHeads * embed, tokenNumPerHead, + BLOCK_SIZE, rowNumRound); + + // copy QRope to L1 + auto layoutARopeSingleNd = layoutARope.GetTileLayout(MakeCoord(curHeadNum, embedRope)); + LayoutAInL1 layoutARopeInL1 = LayoutAInL1::template MakeLayout(rowNum, embedRope); + copyGmToL1A(l1ATensor[rowNumRound * embed], gARope, layoutARopeInL1, layoutARopeSingleNd, tokenNumPerHead, + qHeads * embedRope, tokenNumPerHead, BLOCK_SIZE, rowNumRound); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + + AscendC::WaitFlag(l1KvPingPongFlag); + // copy K to L1 + LayoutBInL1 layoutBInL1 = LayoutBInL1::template MakeLayout(embed, kSeqTile); + copyGmToL1B(l1BTensor[l1KvPingPongFlag], gB, layoutBInL1, layoutB); + AscendC::SetFlag(l1KvPingPongFlag); + + AscendC::WaitFlag(l1KvPingPongFlag + 2); + // copy KRope to L1 + LayoutBInL1 layoutBRopeInL1 = LayoutBInL1::template MakeLayout(embedRope, kSeqTile); + copyGmToL1B(l1BTensor[l1KvPingPongFlag][kSeqTileRound * embed], gBRope, layoutBRopeInL1, layoutBRope); + AscendC::SetFlag(l1KvPingPongFlag + 2); + + for (uint32_t embedSplitIdx = 0; embedSplitIdx < embedSplitLoopK; embedSplitIdx++) { + uint32_t l0ABPingPongFlag = embedSplitIdx % 2; + if (embedSplitIdx == embedSplitLoopK - 1) { + embedSplitSize = embedRope; + } + // copy Q from L1 to l0a + AscendC::WaitFlag(l0ABPingPongFlag); + LayoutAInL1 layoutACatInL1 = LayoutAInL1::template MakeLayout(rowNum, embedCat); + LayoutAInL0 layoutACatInL0 = LayoutAInL0::template MakeLayout(rowNum, embedSplitSize); + copyL1ToL0A(l0ATensor[l0ABPingPongFlag], l1ATensor[embedSplitIdx * rowNumRound * EMBED_SPLIT_SIZE], + layoutACatInL0, layoutACatInL1); + AscendC::SetFlag(l0ABPingPongFlag); + + if (embedSplitIdx == 0) { + AscendC::WaitFlag(l1KvPingPongFlag); + } else if (embedSplitIdx == embedSplitLoopK - 1) { + AscendC::WaitFlag(l1KvPingPongFlag + 2); + } + AscendC::WaitFlag(l0ABPingPongFlag + 2); + // copy K from L1 to l0b + LayoutBInL1 layoutBCatInL1 = LayoutBInL1::template MakeLayout(embedCat, kSeqTile); + LayoutBInL0 layoutBCatInL0 = LayoutBInL0::template MakeLayout(embedSplitSize, kSeqTile); + copyL1ToL0B(l0BTensor[l0ABPingPongFlag], + l1BTensor[l1KvPingPongFlag][embedSplitIdx * kSeqTileRound * EMBED_SPLIT_SIZE], layoutBCatInL0, + layoutBCatInL1); + if (embedSplitIdx == embedSplitLoopK - 1) { + AscendC::SetFlag(l1KvPingPongFlag + 2); + } + AscendC::SetFlag(l0ABPingPongFlag + 2); + // mmad + AscendC::WaitFlag(l0ABPingPongFlag); + AscendC::WaitFlag(l0ABPingPongFlag + 2); + if (embedSplitIdx 
== 0) { + AscendC::WaitFlag(l1KvPingPongFlag); + } + // mmad + tileMmad(l0CTensor[l1KvPingPongFlag], l0ATensor[l0ABPingPongFlag], l0BTensor[l0ABPingPongFlag], rowNumRound, + kSeqTile, embedSplitSize, embedSplitIdx == 0); + AscendC::SetFlag(l0ABPingPongFlag); + AscendC::SetFlag(l0ABPingPongFlag + 2); + } + // copy S to gm + AscendC::SetFlag(l1KvPingPongFlag); + AscendC::WaitFlag(l1KvPingPongFlag); + auto blockShape = MakeCoord(rowNum, kSeqTile); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(blockShape); + // copy L0C to gm + copyL0CToGm(gC, l0CTensor[l1KvPingPongFlag], layoutC, layoutInL0C); + AscendC::SetFlag(l1KvPingPongFlag); + } + +protected: + /// Data members + AscendC::LocalTensor l1ATensor; + AscendC::LocalTensor l1BTensor[STAGES]; + AscendC::LocalTensor l0ATensor[STAGES]; + AscendC::LocalTensor l0BTensor[STAGES]; + AscendC::LocalTensor l0CTensor[STAGES]; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +//////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_MLA_QK_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_mla_qk_tp1_spec.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_mla_qk_tp1_spec.hpp new file mode 100644 index 00000000..e802e7ab --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_mla_qk_tp1_spec.hpp @@ -0,0 +1,244 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_MLA_QK_TP1_SPEC_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_MLA_QK_TP1_SPEC_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +//////////////////////////////////////////////////////////////////// + +namespace Catlass::Gemm::Block { +//////////////////////////////////////////////////////////////////// + +template +struct BlockMmad { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2MLAQKTp1Spec; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + static constexpr uint32_t L0C_PINGPONG_BUF_SIZE = L0C_SIZE / STAGES; + static constexpr uint32_t UNIT_BLOCK_STACK_NUM = 4; + static constexpr uint32_t BLOCK_SIZE = 16; + static constexpr uint32_t EMBED_SPLIT_SIZE = 128; + static constexpr uint32_t EMBED_ROPE = 64; + static constexpr uint32_t GM_L1_EMBED_SPLIT_SIZE = 256; + static constexpr uint32_t EMBED_SPLIT_LOOP = 5; + static constexpr uint32_t L1B_SIZE = L1TileShape::N * GM_L1_EMBED_SPLIT_SIZE * sizeof(ElementB); + static constexpr uint32_t L1BROPE_SIZE = L1TileShape::N * EMBED_ROPE * sizeof(ElementB); + static constexpr uint32_t L1B_ROPE_START = L1TileShape::N * GM_L1_EMBED_SPLIT_SIZE; + + /// Construct + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + // Allocate L1 memory space + l1ATensor = resource.l1Buf.template GetBufferByByte(l1BufAddrStart); + for (uint32_t i = 0; i < STAGES; i++) { + l1BTensor[i] = resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE + L1B_SIZE * i); + l1BRopeTensor[i] = resource.l1Buf.template GetBufferByByte(l1BufAddrStart + L1A_SIZE + + L1B_SIZE * STAGES + L1BROPE_SIZE * i); + l0ATensor[i] = resource.l0ABuf.template 
GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensor[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + l0CTensor[i] = resource.l0CBuf.template GetBufferByByte(L0C_PINGPONG_BUF_SIZE * i); + } + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() {} + + /// Perform a block-scoped matrix multiply-accumulate + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor gA, AscendC::GlobalTensor gARope, + AscendC::GlobalTensor gB, AscendC::GlobalTensor gBRope, + AscendC::GlobalTensor gblockTable, AscendC::GlobalTensor gC, LayoutA layoutA, + LayoutA layoutARope, LayoutB layoutB, LayoutB layoutBRope, LayoutC layoutC, GemmCoord actualShape, + uint32_t &nIdx, uint32_t &nLoop, uint32_t &blockSize, uint32_t kvSeqlen) + { + uint32_t rowNum = actualShape.m(); + uint32_t stackSeqTile = actualShape.n(); + uint32_t seqTile = blockSize; + uint32_t embed = layoutA.shape(1); + uint32_t embedRope = layoutARope.shape(1); + uint32_t embedCat = actualShape.k(); + uint32_t embedSplitGm2L1 = GM_L1_EMBED_SPLIT_SIZE; + uint32_t stackSeqTileRound = layoutC.shape(1); + uint32_t rowNumRound = layoutC.shape(0); + uint32_t seqTileRound = RoundUp(seqTile); + + if (nIdx == 0) { + // copy Q to L1 + LayoutAInL1 layoutAInL1 = LayoutAInL1::template MakeLayout(rowNum, embed); + copyGmToL1A(l1ATensor, gA, layoutAInL1, layoutA); + + // copy QRope to L1 + LayoutAInL1 layoutARopeInL1 = LayoutAInL1::template MakeLayout(rowNum, embedRope); + copyGmToL1A(l1ATensor[rowNumRound * embed], gARope, layoutARopeInL1, layoutARope); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + + for (uint32_t blockStackIdx = 0; (blockStackIdx < UNIT_BLOCK_STACK_NUM) && ((nIdx + blockStackIdx) < nLoop); + blockStackIdx++) { + uint32_t nIdxActual = nIdx + blockStackIdx; + uint32_t L0CPingPongFlag = nIdxActual % 2; + uint32_t L1BRopePingPongFlag = nIdxActual % 2; + if (nIdxActual == (nLoop - 1)) { + seqTile = (kvSeqlen - nIdxActual * blockSize); + seqTileRound = RoundUp(seqTile); + } + uint32_t blockTableId = gblockTable.GetValue(nIdxActual); + uint64_t kvOffset = (uint64_t)blockTableId * blockSize * embed; + uint64_t kvOffsetRope = (uint64_t)blockTableId * blockSize * embedRope; + uint64_t l1bSplitOffset = 0; + uint32_t embedSplitSize = EMBED_SPLIT_SIZE; + uint32_t embedSplitLoopK = EMBED_SPLIT_LOOP; + for (uint32_t embedSplitIdx = 0; embedSplitIdx < embedSplitLoopK; embedSplitIdx++) { + uint32_t L0ABPingPongFlag = (blockStackIdx + embedSplitIdx) % 2; + uint32_t innerSplitIdxL12L0 = embedSplitIdx % 2; + uint32_t L1BPingPongFlag = embedSplitIdx / 2; + if (embedSplitIdx == 4) { + embedSplitSize = embedRope; + } + // copy Q to l0a + AscendC::WaitFlag(L0ABPingPongFlag); + LayoutAInL1 layoutACatSplitKInL1 = LayoutAInL1::template MakeLayout(rowNum, embedSplitSize); + LayoutAInL0 layoutACatSplitKInL0 = LayoutAInL0::template MakeLayout(rowNum, embedSplitSize); + copyL1ToL0A(l0ATensor[L0ABPingPongFlag], l1ATensor[embedSplitIdx * rowNumRound * EMBED_SPLIT_SIZE], + layoutACatSplitKInL0, layoutACatSplitKInL1); + AscendC::SetFlag(L0ABPingPongFlag); + + if (embedSplitIdx == 0 || embedSplitIdx == 2) { + // copy K to l1b + AscendC::WaitFlag(L1BPingPongFlag); + auto layoutUnitBSplitK = layoutB.GetTileLayout(MakeCoord(embedSplitGm2L1, seqTile)); + LayoutBInL1 layoutUnitBSplitKInL1 = + LayoutBInL1::template MakeLayout(embedSplitGm2L1, seqTile); + copyGmToL1B(l1BTensor[L1BPingPongFlag], gB[kvOffset + embedSplitIdx * embedSplitSize], + layoutUnitBSplitKInL1, layoutUnitBSplitK); + AscendC::SetFlag(L1BPingPongFlag); + 
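// Illustrative sketch (plain standalone C++, not AscendC): the accumulate-over-splits
// pattern used by the embedding-split loop here, where the first split initializes the
// accumulator (the "embedSplitIdx == 0" flag handed to tileMmad) and later splits add
// onto it. This is an unoptimized row-major reference; sizes and names are sketch-only.
#include <cstdint>
#include <vector>

void GemmAccumulateOverSplits(const std::vector<float> &a,  // m x kTotal
                              const std::vector<float> &b,  // kTotal x n
                              std::vector<float> &c,        // m x n
                              uint32_t m, uint32_t n, uint32_t kTotal, uint32_t kSplit)
{
    for (uint32_t k0 = 0; k0 < kTotal; k0 += kSplit) {
        const bool initC = (k0 == 0);  // same role as "embedSplitIdx == 0" above
        const uint32_t kEnd = (k0 + kSplit < kTotal) ? (k0 + kSplit) : kTotal;
        for (uint32_t i = 0; i < m; ++i) {
            for (uint32_t j = 0; j < n; ++j) {
                float acc = initC ? 0.0f : c[i * n + j];
                for (uint32_t k = k0; k < kEnd; ++k) {
                    acc += a[i * kTotal + k] * b[k * n + j];
                }
                c[i * n + j] = acc;
            }
        }
    }
}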
AscendC::WaitFlag(L1BPingPongFlag); + } else if (embedSplitIdx == 4) { + // copy KRope to L1 + AscendC::WaitFlag(L1BRopePingPongFlag + 2); + auto layoutUnitBRope = layoutBRope.GetTileLayout(MakeCoord(embedRope, seqTile)); + LayoutBInL1 layoutBRopeInL1 = LayoutBInL1::template MakeLayout(embedRope, seqTile); + copyGmToL1B(l1BRopeTensor[L1BRopePingPongFlag], gBRope[kvOffsetRope], layoutBRopeInL1, + layoutUnitBRope); + AscendC::SetFlag(L1BRopePingPongFlag); + AscendC::WaitFlag(L1BRopePingPongFlag); + } + // copy K to l0b + AscendC::WaitFlag(L0ABPingPongFlag + 2); + LayoutBInL1 layoutBCatSplitKInL1 = LayoutBInL1::template MakeLayout(embedSplitSize, seqTile); + LayoutBInL0 layoutBCatSplitKInL0 = LayoutBInL0::template MakeLayout(embedSplitSize, seqTile); + if (embedSplitIdx != 4) { + copyL1ToL0B(l0BTensor[L0ABPingPongFlag], + l1BTensor[L1BPingPongFlag][innerSplitIdxL12L0 * EMBED_SPLIT_SIZE * seqTileRound], + layoutBCatSplitKInL0, layoutBCatSplitKInL1); + } else { + copyL1ToL0B(l0BTensor[L0ABPingPongFlag], l1BRopeTensor[L1BRopePingPongFlag], layoutBCatSplitKInL0, + layoutBCatSplitKInL1); + } + + if (embedSplitIdx == 1 || embedSplitIdx == 3) { + AscendC::SetFlag(L1BPingPongFlag); + } else if (embedSplitIdx == 4) { + AscendC::SetFlag(L1BRopePingPongFlag + 2); + } + AscendC::SetFlag(L0ABPingPongFlag + 2); + + AscendC::WaitFlag(L0ABPingPongFlag); + AscendC::WaitFlag(L0ABPingPongFlag + 2); + if (embedSplitIdx == 0) { + AscendC::WaitFlag(L0CPingPongFlag); + } + // mmad + tileMmad(l0CTensor[L0CPingPongFlag], l0ATensor[L0ABPingPongFlag], l0BTensor[L0ABPingPongFlag], + rowNumRound, seqTile, embedSplitSize, embedSplitIdx == 0); + AscendC::SetFlag(L0ABPingPongFlag); + AscendC::SetFlag(L0ABPingPongFlag + 2); + } + // copy S to gm + AscendC::SetFlag(L0CPingPongFlag); + AscendC::WaitFlag(L0CPingPongFlag); + auto blockShape = MakeCoord(rowNum, seqTile); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(blockShape); + auto layoutCSplitN = layoutC.GetTileLayout(MakeCoord(rowNumRound, seqTileRound)); + // copy L0C to gm + copyL0CToGm(gC[blockStackIdx * blockSize], l0CTensor[L0CPingPongFlag], layoutCSplitN, layoutInL0C); + AscendC::SetFlag(L0CPingPongFlag); + } + } + +protected: + /// Data members + AscendC::LocalTensor l1ATensor; + AscendC::LocalTensor l1BTensor[STAGES]; + AscendC::LocalTensor l1BRopeTensor[STAGES]; + AscendC::LocalTensor l0ATensor[STAGES]; + AscendC::LocalTensor l0BTensor[STAGES]; + AscendC::LocalTensor l0CTensor[STAGES]; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +//////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_MLA_QK_TP1_SPEC_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong.hpp new file mode 100644 index 00000000..f944cb78 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong.hpp @@ -0,0 +1,335 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +namespace Catlass::Gemm::Block { + +template +struct BlockMmad, L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, + TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2Pingpong; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert((L1A_SIZE * STAGES + L1B_SIZE * STAGES) <= ArchTag::L1_SIZE, "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L0TileShape::M * L0TileShape::N * sizeof(ElementAccumulator); + static_assert((L0A_TILE_SIZE * STAGES) <= L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert((L0B_TILE_SIZE * STAGES) <= L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE <= L0C_SIZE, "L0TileShape 
exceeding the L0C space!"); + + static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + static_assert(L0TileShape::K <= L1TileShape::K, "L0TileShape::K cannot exceed L1TileShape::K"); + + /// Construct + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_SIZE * STAGES; + // Init buffers + for (uint32_t i = 0; i < STAGES; i++) { + // Assign L1/L0A/L0B space for each stages + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_SIZE * i); + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + + // Assign event ID for each stages + l1AEventList[i] = i; + l1BEventList[i] = i + STAGES; + l0AEventList[i] = i; + l0BEventList[i] = i + STAGES; + + // The event id that needs to be set before the loop + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + AscendC::SetFlag(l0AEventList[i]); + AscendC::SetFlag(l0BEventList[i]); + } + l0CTensor = resource.l0CBuf.template GetBufferByByte(0); + AscendC::SetFlag(EVENT_ID0); + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() + { + for (uint32_t i = 0; i < STAGES; i++) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + AscendC::WaitFlag(l0AEventList[i]); + AscendC::WaitFlag(l0BEventList[i]); + } + AscendC::WaitFlag(EVENT_ID0); + } + + /// Perform a block-scoped matrix multiply-accumulate + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmC, LayoutC const &layoutC, GemmCoord const &actualShape) + { + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + auto layoutAInL1 = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + auto layoutBInL1 = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(MakeCoord(mRound, nRound)); + + uint32_t kActual = min(actualShape.k(), L1TileShape::K); + + // load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1ListId], gmA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + + // load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1ListId], gmB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::WaitFlag(EVENT_ID0); + } + + uint32_t mPartLoop = CeilDiv(mRound); + uint32_t nPartLoop = CeilDiv(nRound); + + // main loop + uint32_t kTileCount = CeilDiv(actualShape.k()); + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; kLoopIdx++) { + uint32_t l1ListIdNext = (l1ListId + 1 < STAGES) ? 
(l1ListId + 1) : 0; + uint32_t kActualNext{0}; + // preload next tile from GM to L1 + if (kLoopIdx < kTileCount - 1) { + uint32_t kLoopIdxNext = kLoopIdx + 1; + kActualNext = (kLoopIdxNext < kTileCount - 1) ? L1TileShape::K + : (actualShape.k() - kLoopIdxNext * L1TileShape::K); + + // Get L1 tensor for next stage + auto l1ATensor = l1ATensorList[l1ListIdNext]; + auto l1BTensor = l1BTensorList[l1ListIdNext]; + // Get GM tile for next stage + MatrixCoord gmTileAOffset{0, kLoopIdxNext * L1TileShape::K}; + MatrixCoord gmTileBOffset{kLoopIdxNext * L1TileShape::K, 0}; + auto gmTileA = gmA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmB[layoutB.GetOffset(gmTileBOffset)]; + + // load next matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActualNext)); + copyGmToL1A(l1ATensor, gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + + // load next matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + layoutTileB = layoutB.GetTileLayout(MakeCoord(kActualNext, actualShape.n())); + copyGmToL1B(l1BTensor, gmTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + + // Get L1 tensor for current stage + auto l1ATensor = l1ATensorList[l1ListId]; + auto l1BTensor = l1BTensorList[l1ListId]; + + // Get the loop nums on L0 + uint32_t kPartLoop = CeilDiv(kActual); + + for (int mPartIdx = 0; mPartIdx < mPartLoop; mPartIdx++) { + uint32_t mPartActual = + (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (mRound - mPartIdx * L0TileShape::M); + + for (int kPartIdx = 0; kPartIdx < kPartLoop; kPartIdx++) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) ? L0TileShape::K : (kActual - kPartIdx * L0TileShape::K); + + // Locate the current tile on L0A + auto l0ATile = l0ATensorList[l0AListId]; + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + // Locate the current tile of matrix A on L1 + MatrixCoord l1AOffset{mPartIdx * L0TileShape::M, kPartIdx * L0TileShape::K}; + auto l1ATile = l1ATensor[layoutAInL1.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[l1ListId]); + } + + // Load current tile from L1 to L0A + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, layoutAInL1); + + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[l1ListId]); + } + + for (int nPartIdx = 0; nPartIdx < nPartLoop; nPartIdx++) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? 
L0TileShape::N : (nRound - nPartIdx * L0TileShape::N); + + // Locate the current tile on L0B + auto l0BTile = l0BTensorList[l0BListId]; + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + // Locate the current tile of matrix B on L1 + MatrixCoord l1BOffset{kPartIdx * L0TileShape::K, nPartIdx * L0TileShape::N}; + auto l1BTile = l1BTensor[layoutBInL1.GetOffset(l1BOffset)]; + + // Wait for mmad finished + AscendC::WaitFlag(l0BEventList[l0BListId]); + // If the current tile is the first one on the k&n axis, wait for loading matrix B from GM to L1 + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[l1ListId]); + } + + // Load current tile from L1 to L0B + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, layoutBInL1); + + // If the current tile is the last one on the k&n axis, notify to load matrix B from GM to L1 + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[l1ListId]); + } + // Notify to do mmad + AscendC::SetFlag(EVENT_ID0); + + // Locate the current tile on L0C + MatrixCoord l0COffset{mPartIdx * L0TileShape::M, nPartIdx * L0TileShape::N}; + auto l0CTile = l0CTensor[layoutInL0C.GetOffset(l0COffset)]; + + // Compute the matrix multiplication on L0A and L0B and write the result to the accumulator + // Wait for loading L0B + AscendC::WaitFlag(EVENT_ID0); + + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = ((kLoopIdx == 0) && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if ((kLoopIdx == kTileCount - 1) && (mPartIdx == mPartLoop - 1) && + (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + // Perform calculation operations + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + // Notify to move the next L0B tile + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < STAGES) ? 
(l0AListId + 1) : 0; + } + } + l1ListId = l1ListIdNext; + kActual = kActualNext; + } + + // copy block out + LayoutC layoutBlock = layoutC.GetTileLayout(actualShape.GetCoordMN()); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C); + AscendC::SetFlag(EVENT_ID0); + } else { + copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C, 0b11); + } + } + +protected: + // Multi-stage tensors list + AscendC::LocalTensor l1ATensorList[STAGES]; + AscendC::LocalTensor l1BTensorList[STAGES]; + AscendC::LocalTensor l0ATensorList[STAGES]; + AscendC::LocalTensor l0BTensorList[STAGES]; + AscendC::LocalTensor l0CTensor; + + // Multi-stage event id list + int32_t l1AEventList[STAGES]; + int32_t l1BEventList[STAGES]; + int32_t l0AEventList[STAGES]; + int32_t l0BEventList[STAGES]; + + // The id of current stage + uint32_t l1ListId{0}; + uint32_t l0AListId{0}; + uint32_t l0BListId{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong_bias.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong_bias.hpp new file mode 100644 index 00000000..a8827bcd --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong_bias.hpp @@ -0,0 +1,364 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_BIAS_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_BIAS_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +namespace Catlass::Gemm::Block { + +template +struct BlockMmad, L1TileShape_, L0TileShape_, AType_, BType_, CType_, + BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2PingpongBias; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementBias = typename BiasType_::Element; + using LayoutBias = typename BiasType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyGmToL1Bias = typename TileCopy_::CopyGmToL1Bias; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using CopyL1ToBT = typename TileCopy_::CopyL1ToBT; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L1BIAS_SIZE = L1TileShape::N * sizeof(ElementBias); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t BT_SIZE = ArchTag::BIAS_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + static constexpr uint32_t BIAS_BUF_SIZE = L0TileShape::N * sizeof(ElementAccumulator); + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert((L1A_SIZE * STAGES + L1B_SIZE * STAGES + L1BIAS_SIZE) <= ArchTag::L1_SIZE, + "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L0TileShape::M * L0TileShape::N * sizeof(ElementAccumulator); + static_assert((L0A_TILE_SIZE * STAGES) <= L0A_SIZE, "L0TileShape exceeding the 
L0A space!"); + static_assert((L0B_TILE_SIZE * STAGES) <= L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE <= L0C_SIZE, "L0TileShape exceeding the L0C space!"); + static_assert(BIAS_BUF_SIZE <= BT_SIZE, "BIAS_BUF_SIZE exceeding the BT space! Reduce L0TileShape::N"); + + static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + static_assert(L0TileShape::K <= L1TileShape::K, "L0TileShape::K cannot exceed L1TileShape::K"); + + /// Construct + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_SIZE * STAGES; + uint32_t l1BiasOffset = l1BufAddrStart + L1A_SIZE * STAGES + L1B_SIZE * STAGES; + // Init buffers + for (uint32_t i = 0; i < STAGES; i++) { + // Assign L1/L0A/L0B space for each stages + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_SIZE * i); + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + + // Assign event ID for each stages + l1AEventList[i] = i; + l1BEventList[i] = i + STAGES; + l0AEventList[i] = i; + l0BEventList[i] = i + STAGES; + + // The event id that needs to be set before the loop + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + AscendC::SetFlag(l0AEventList[i]); + AscendC::SetFlag(l0BEventList[i]); + } + l0CTensor = resource.l0CBuf.template GetBufferByByte(0); + l1BiasTensor = resource.l1Buf.template GetBufferByByte(l1BiasOffset); + l0BiasTensor = resource.btBuf.template GetBufferByByte(0); + AscendC::SetFlag(EVENT_ID0); + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() + { + for (uint32_t i = 0; i < STAGES; i++) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + AscendC::WaitFlag(l0AEventList[i]); + AscendC::WaitFlag(l0BEventList[i]); + } + AscendC::WaitFlag(EVENT_ID0); + } + + /// Perform a block-scoped matrix multiply-accumulate + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmC, LayoutC const &layoutC, + AscendC::GlobalTensor const &gmBias, GemmCoord const &actualShape) + { + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + auto layoutAInL1 = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + auto layoutBInL1 = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + auto layoutBiasInL1 = layout::VectorLayout(L1TileShape::N); + auto layoutBiasInL0 = layout::VectorLayout(L0TileShape::N); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(MakeCoord(mRound, nRound)); + + uint32_t kActual = min(actualShape.k(), L1TileShape::K); + + // load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1ListId], gmA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + + // load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + 
copyGmToL1B(l1BTensorList[l1ListId], gmB, layoutBInL1, layoutTileB); + auto layoutTileBias = layout::VectorLayout(actualShape.n()); + copyGmToL1Bias(l1BiasTensor, gmBias, layoutBiasInL1, layoutTileBias); + AscendC::SetFlag(l1BEventList[l1ListId]); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::WaitFlag(EVENT_ID0); + } + + uint32_t mPartLoop = CeilDiv(mRound); + uint32_t nPartLoop = CeilDiv(nRound); + + // main loop + uint32_t kTileCount = CeilDiv(actualShape.k()); + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; kLoopIdx++) { + uint32_t l1ListIdNext = (l1ListId + 1 < STAGES) ? (l1ListId + 1) : 0; + uint32_t kActualNext{0}; + // preload next tile from GM to L1 + if (kLoopIdx < kTileCount - 1) { + uint32_t kLoopIdxNext = kLoopIdx + 1; + kActualNext = (kLoopIdxNext < kTileCount - 1) ? L1TileShape::K + : (actualShape.k() - kLoopIdxNext * L1TileShape::K); + + // Get L1 tensor for next stage + auto l1ATensor = l1ATensorList[l1ListIdNext]; + auto l1BTensor = l1BTensorList[l1ListIdNext]; + // Get GM tile for next stage + MatrixCoord gmTileAOffset{0, kLoopIdxNext * L1TileShape::K}; + MatrixCoord gmTileBOffset{kLoopIdxNext * L1TileShape::K, 0}; + auto gmTileA = gmA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmB[layoutB.GetOffset(gmTileBOffset)]; + + // load next matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActualNext)); + copyGmToL1A(l1ATensor, gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + + // load next matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + layoutTileB = layoutB.GetTileLayout(MakeCoord(kActualNext, actualShape.n())); + copyGmToL1B(l1BTensor, gmTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + + // Get L1 tensor for current stage + auto l1ATensor = l1ATensorList[l1ListId]; + auto l1BTensor = l1BTensorList[l1ListId]; + + // Get the loop nums on L0 + uint32_t kPartLoop = CeilDiv(kActual); + + for (int mPartIdx = 0; mPartIdx < mPartLoop; mPartIdx++) { + uint32_t mPartActual = + (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (mRound - mPartIdx * L0TileShape::M); + + for (int kPartIdx = 0; kPartIdx < kPartLoop; kPartIdx++) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) ? L0TileShape::K : (kActual - kPartIdx * L0TileShape::K); + + // Locate the current tile on L0A + auto l0ATile = l0ATensorList[l0AListId]; + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + // Locate the current tile of matrix A on L1 + MatrixCoord l1AOffset{mPartIdx * L0TileShape::M, kPartIdx * L0TileShape::K}; + auto l1ATile = l1ATensor[layoutAInL1.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[l1ListId]); + } + + // Load current tile from L1 to L0A + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, layoutAInL1); + + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[l1ListId]); + } + + for (int nPartIdx = 0; nPartIdx < nPartLoop; nPartIdx++) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? 
L0TileShape::N : (nRound - nPartIdx * L0TileShape::N); + + // Locate the current tile on L0B + auto l0BTile = l0BTensorList[l0BListId]; + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + // Locate the current tile of matrix B on L1 + MatrixCoord l1BOffset{kPartIdx * L0TileShape::K, nPartIdx * L0TileShape::N}; + auto l1BTile = l1BTensor[layoutBInL1.GetOffset(l1BOffset)]; + + // Wait for mmad finished + AscendC::WaitFlag(l0BEventList[l0BListId]); + // If the current tile is the first one on the k&n axis, wait for loading matrix B from GM to L1 + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[l1ListId]); + } + + // Load current tile from L1 to L0B + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, layoutBInL1); + + // Load bias to l0 biastable + copyL1ToBT(l0BiasTensor, l1BiasTensor, layoutBiasInL0, layoutBiasInL1); + + // If the current tile is the last one on the k&n axis, notify to load matrix B from GM to L1 + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[l1ListId]); + } + // Notify to do mmad + AscendC::SetFlag(EVENT_ID0); + + // Locate the current tile on L0C + MatrixCoord l0COffset{mPartIdx * L0TileShape::M, nPartIdx * L0TileShape::N}; + auto l0CTile = l0CTensor[layoutInL0C.GetOffset(l0COffset)]; + + // Compute the matrix multiplication on L0A and L0B and write the result to the accumulator + // Wait for loading L0B + AscendC::WaitFlag(EVENT_ID0); + + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = ((kLoopIdx == 0) && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if ((kLoopIdx == kTileCount - 1) && (mPartIdx == mPartLoop - 1) && + (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + // Perform calculation operations + if (initC) { + tileMmad(l0CTile, l0ATile, l0BTile, l0BiasTensor, mPartActual, nPartActual, kPartActual, + initC, unitFlag); + } else { + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + } + + // Notify to move the next L0B tile + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < STAGES) ? 
(l0AListId + 1) : 0; + } + } + l1ListId = l1ListIdNext; + kActual = kActualNext; + } + + // copy block out + LayoutC layoutBlock = layoutC.GetTileLayout(actualShape.GetCoordMN()); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C); + AscendC::SetFlag(EVENT_ID0); + } else { + copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C, 0b11); + } + } + +protected: + // Multi-stage tensors list + AscendC::LocalTensor l1ATensorList[STAGES]; + AscendC::LocalTensor l1BTensorList[STAGES]; + AscendC::LocalTensor l1BiasTensor; + AscendC::LocalTensor l0ATensorList[STAGES]; + AscendC::LocalTensor l0BTensorList[STAGES]; + AscendC::LocalTensor l0CTensor; + AscendC::LocalTensor l0BiasTensor; + + // Multi-stage event id list + int32_t l1AEventList[STAGES]; + int32_t l1BEventList[STAGES]; + int32_t l0AEventList[STAGES]; + int32_t l0BEventList[STAGES]; + + // The id of current stage + uint32_t l1ListId{0}; + uint32_t l0AListId{0}; + uint32_t l0BListId{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyGmToL1Bias copyGmToL1Bias; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL1ToBT copyL1ToBT; + CopyL0CToGm copyL0CToGm; +}; + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_BIAS_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong_full_loadA.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong_full_loadA.hpp new file mode 100644 index 00000000..70459f9b --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong_full_loadA.hpp @@ -0,0 +1,323 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
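The bias-enabled variant that ends above differs from the base ping-pong block mainly in how the accumulator is initialized: on the first k tile, tileMmad is called with the bias table so C starts from bias rather than zero (initC == true), and every later k tile accumulates on top of that, so the bias is applied exactly once per output tile. Below is a plain C++ reference of that arithmetic, assuming row-major float buffers; all names are local to this sketch and not part of the patch.

#include <algorithm>
#include <cstddef>
#include <vector>

void gemm_bias_reference(const std::vector<float>& A, const std::vector<float>& B,
                         const std::vector<float>& bias, std::vector<float>& C,
                         std::size_t M, std::size_t N, std::size_t K, std::size_t kTile)
{
    for (std::size_t k0 = 0; k0 < K; k0 += kTile) {           // one iteration per k tile
        bool initC = (k0 == 0);                                // first k tile initializes C
        std::size_t kEnd = std::min(K, k0 + kTile);
        for (std::size_t m = 0; m < M; ++m) {
            for (std::size_t n = 0; n < N; ++n) {
                float acc = initC ? bias[n] : C[m * N + n];    // bias added only on init
                for (std::size_t k = k0; k < kEnd; ++k) {
                    acc += A[m * K + k] * B[k * N + n];
                }
                C[m * N + n] = acc;
            }
        }
    }
}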
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_FULL_LOADA_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_FULL_LOADA_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +namespace Catlass::Gemm::Block { + +template +struct BlockMmad, L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, + TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2FullLoadA; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = ArchTag::L1_SIZE / 2; + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape in example, because problemShape.k() cannot get here. 
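Because this variant keeps the entire A block (M x a runtime K) resident in half of L1, the capacity check referred to in the comment above cannot be a static_assert inside the struct; the caller has to verify the fit before dispatching. A minimal host-side guard could look like the following; the 512 KiB L1 size and the element width are example values for this sketch, not taken from the patch.

#include <cstddef>
#include <stdexcept>

constexpr std::size_t kL1SizeBytes = 512 * 1024;            // example stand-in for ArchTag::L1_SIZE
constexpr std::size_t kL1ABudgetBytes = kL1SizeBytes / 2;   // mirrors L1A_SIZE = L1_SIZE / 2

void check_full_load_a_fits(std::size_t m, std::size_t k, std::size_t elementBytes)
{
    if (m * k * elementBytes > kL1ABudgetBytes) {
        throw std::runtime_error("A block does not fit in the reserved half of L1; "
                                 "use a k-tiled BlockMmad variant instead");
    }
}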
+ + // Check L0TileShape + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L0TileShape::M * L0TileShape::N * sizeof(ElementAccumulator); + static_assert((L0A_TILE_SIZE * STAGES) <= L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert((L0B_TILE_SIZE * STAGES) <= L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE <= L0C_SIZE, "L0TileShape exceeding the L0C space!"); + + static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + + /// Construct + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_SIZE; + // Init L1A related + l1ATensorList[0] = resource.l1Buf.template GetBufferByByte(l1AOffset); + l1AEventList[0] = 0; + + // Init buffers + for (uint32_t i = 0; i < STAGES; i++) { + // Assign L1/L0A/L0B space for each stages + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_SIZE * i); + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + + // Assign event ID for each stages + l1BEventList[i] = i + STAGES; + l0AEventList[i] = i; + l0BEventList[i] = i + STAGES; + + // The event id that needs to be set before the loop + AscendC::SetFlag(l1BEventList[i]); + AscendC::SetFlag(l0AEventList[i]); + AscendC::SetFlag(l0BEventList[i]); + } + l0CTensor = resource.l0CBuf.template GetBufferByByte(0); + AscendC::SetFlag(EVENT_ID0); + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() + { + for (uint32_t i = 0; i < STAGES; i++) { + AscendC::WaitFlag(l1BEventList[i]); + AscendC::WaitFlag(l0AEventList[i]); + AscendC::WaitFlag(l0BEventList[i]); + } + AscendC::WaitFlag(EVENT_ID0); + } + + /// Perform a block-scoped matrix multiply-accumulate + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmC, LayoutC const &layoutC, GemmCoord const &actualShape, + bool needLoadL1) + { + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + auto layoutAInL1 = LayoutAInL1::template MakeLayout(L1TileShape::M, actualShape.k()); + auto layoutBInL1 = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(MakeCoord(mRound, nRound)); + + uint32_t kActual = min(actualShape.k(), L1TileShape::K); + + if (needLoadL1) { + // load first matrix A tile from GM to L1 + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), actualShape.k())); + AscendC::SetFlag(l1AEventList[0]); + AscendC::WaitFlag(l1AEventList[0]); + copyGmToL1A(l1ATensorList[0], gmA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[0]); + AscendC::WaitFlag(l1AEventList[0]); + } + + // load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1ListId], gmB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + + if constexpr (!ENABLE_UNIT_FLAG) { + 
AscendC::WaitFlag(EVENT_ID0); + } + + uint32_t mPartLoop = CeilDiv(mRound); + uint32_t nPartLoop = CeilDiv(nRound); + + // main loop + uint32_t kTileCount = CeilDiv(actualShape.k()); + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; kLoopIdx++) { + uint32_t l1ListIdNext = (l1ListId + 1 < STAGES) ? (l1ListId + 1) : 0; + uint32_t kActualNext{0}; + // preload next tile from GM to L1 + if (kLoopIdx < kTileCount - 1) { + uint32_t kLoopIdxNext = kLoopIdx + 1; + kActualNext = (kLoopIdxNext < kTileCount - 1) ? L1TileShape::K + : (actualShape.k() - kLoopIdxNext * L1TileShape::K); + + // Get L1 tensor for next stage + auto l1BTensor = l1BTensorList[l1ListIdNext]; + // Get GM tile for next stage + MatrixCoord gmTileBOffset{kLoopIdxNext * L1TileShape::K, 0}; + auto gmTileB = gmB[layoutB.GetOffset(gmTileBOffset)]; + + // load next matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + layoutTileB = layoutB.GetTileLayout(MakeCoord(kActualNext, actualShape.n())); + copyGmToL1B(l1BTensor, gmTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + + // Get L1 tensor for current stage + auto l1ATensor = l1ATensorList[0]; + auto l1BTensor = l1BTensorList[l1ListId]; + + // Get the loop nums on L0 + uint32_t kPartLoop = CeilDiv(kActual); + + for (int mPartIdx = 0; mPartIdx < mPartLoop; mPartIdx++) { + uint32_t mPartActual = + (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (mRound - mPartIdx * L0TileShape::M); + + for (int kPartIdx = 0; kPartIdx < kPartLoop; kPartIdx++) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) ? L0TileShape::K : (kActual - kPartIdx * L0TileShape::K); + + // Locate the current tile on L0A + auto l0ATile = l0ATensorList[l0AListId]; + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + // Locate the current tile of matrix A on L1 + MatrixCoord l1AOffset{mPartIdx * L0TileShape::M, + kPartIdx * L0TileShape::K + kLoopIdx * L1TileShape::K}; + auto l1ATile = l1ATensor[layoutAInL1.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + + // Load current tile from L1 to L0A + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, layoutAInL1); + + for (int nPartIdx = 0; nPartIdx < nPartLoop; nPartIdx++) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? 
L0TileShape::N : (nRound - nPartIdx * L0TileShape::N); + + // Locate the current tile on L0B + auto l0BTile = l0BTensorList[l0BListId]; + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + // Locate the current tile of matrix B on L1 + MatrixCoord l1BOffset{kPartIdx * L0TileShape::K, nPartIdx * L0TileShape::N}; + auto l1BTile = l1BTensor[layoutBInL1.GetOffset(l1BOffset)]; + + // Wait for mmad finished + AscendC::WaitFlag(l0BEventList[l0BListId]); + // If the current tile is the first one on the k&n axis, wait for loading matrix B from GM to L1 + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[l1ListId]); + } + + // Load current tile from L1 to L0B + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, layoutBInL1); + + // If the current tile is the last one on the k&n axis, notify to load matrix B from GM to L1 + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[l1ListId]); + } + // Notify to do mmad + AscendC::SetFlag(EVENT_ID0); + + // Locate the current tile on L0C + MatrixCoord l0COffset{mPartIdx * L0TileShape::M, nPartIdx * L0TileShape::N}; + auto l0CTile = l0CTensor[layoutInL0C.GetOffset(l0COffset)]; + + // Compute the matrix multiplication on L0A and L0B and write the result to the accumulator + // Wait for loading L0B + AscendC::WaitFlag(EVENT_ID0); + + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = ((kLoopIdx == 0) && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if ((kLoopIdx == kTileCount - 1) && (mPartIdx == mPartLoop - 1) && + (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + // Perform calculation operations + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + // Notify to move the next L0B tile + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < STAGES) ? 
(l0AListId + 1) : 0; + } + } + l1ListId = l1ListIdNext; + kActual = kActualNext; + } + + // copy block out + LayoutC layoutBlock = layoutC.GetTileLayout(actualShape.GetCoordMN()); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C); + AscendC::SetFlag(EVENT_ID0); + } else { + copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C, 0b11); + } + } + +protected: + // Multi-stage tensors list + AscendC::LocalTensor l1ATensorList[STAGES]; + AscendC::LocalTensor l1BTensorList[STAGES]; + AscendC::LocalTensor l0ATensorList[STAGES]; + AscendC::LocalTensor l0BTensorList[STAGES]; + AscendC::LocalTensor l0CTensor; + + // Multi-stage event id list + int32_t l1AEventList[STAGES]; + int32_t l1BEventList[STAGES]; + int32_t l0AEventList[STAGES]; + int32_t l0BEventList[STAGES]; + + // The id of current stage + uint32_t l1ListId{0}; + uint32_t l0AListId{0}; + uint32_t l0BListId{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_FULL_LOADA_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong_slice_k.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong_slice_k.hpp new file mode 100644 index 00000000..3431ed70 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong_slice_k.hpp @@ -0,0 +1,359 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
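The full-load-A variant that ends above exposes its reuse through the needLoadL1 argument: A is copied from GM to L1 once, and subsequent calls covering other n-blocks of the same output rows pass needLoadL1 == false and skip that copy. A trivial sketch of the calling pattern follows; the loop bounds are invented and the block call itself is device code, only hinted at in a comment.

#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t nBlocks = 4;   // example: n-blocks that share the same A rows
    bool aResident = false;       // becomes true after the first block loads A into L1
    for (uint32_t nBlockIdx = 0; nBlockIdx < nBlocks; ++nBlockIdx) {
        bool needLoadL1 = !aResident;   // only the first call pays for the GM -> L1 copy of A
        std::printf("n-block %u: needLoadL1=%d\n", (unsigned)nBlockIdx, needLoadL1 ? 1 : 0);
        // blockMmad(gmA, layoutA, gmB, layoutB, gmC, layoutC, actualShape, needLoadL1);  // device call, not shown
        aResident = true;
    }
    return 0;
}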
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_SLICE_K_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_SLICE_K_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm/block/block_dequant.hpp" + +namespace Catlass::Gemm::Block { + +template +struct BlockMmad, L1TileShape_, L0TileShape_, AType_, BType_, CType_, + BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2PingpongSliceK; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + + // Check L1TileShape + static_assert((L1A_SIZE * STAGES + L1B_SIZE * STAGES) <= ArchTag::L1_SIZE, "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static_assert((L0A_TILE_SIZE * STAGES) <= L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert((L0B_TILE_SIZE * STAGES) <= L0B_SIZE, "L0TileShape exceeding the L0B space!"); + + /// Construct + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_SIZE * STAGES; + // Init buffers + for (uint32_t i = 0; i < STAGES; i++) { + // Assign L1/L0A/L0B space for each stages + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_SIZE * i); + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensorList[i] = resource.l0BBuf.template 
GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + + // Assign event ID for each stages + l1AEventList[i] = i; + l1BEventList[i] = i + STAGES; + l0AEventList[i] = i; + l0BEventList[i] = i + STAGES; + + // The event id that needs to be set before the loop + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + AscendC::SetFlag(l0AEventList[i]); + AscendC::SetFlag(l0BEventList[i]); + } + l0CTensor = resource.l0CBuf.template GetBufferByByte(0); + AscendC::SetFlag(EVENT_ID0); + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() + { + for (uint32_t i = 0; i < STAGES; i++) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + AscendC::WaitFlag(l0AEventList[i]); + AscendC::WaitFlag(l0BEventList[i]); + } + AscendC::WaitFlag(EVENT_ID0); + } + + /// Perform a block-scoped matrix multiply-accumulate + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmWA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmWB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmC, LayoutC const &layoutC, + AscendC::GlobalTensor const &gmNextWA, LayoutA const &layoutNextA, + AscendC::GlobalTensor const &gmNextWB, LayoutB const &layoutNextB, + GemmCoord const &actualShape, GemmCoord const &nextActualShape, bool isFirstKSlice, + bool isFirstBlock, bool hasNextBlock) + { + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + auto layoutAInL1 = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + auto layoutBInL1 = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(MakeCoord(mRound, nRound)); + + uint32_t kActual = min(actualShape.k(), L1TileShape::K); + + // load first matrix A tile from GM to L1 + if (isFirstBlock) { + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1ListId], gmWA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + + // load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1ListId], gmWB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + } + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::WaitFlag(EVENT_ID0); + } + + uint32_t mPartLoop = CeilDiv(mRound); + uint32_t nPartLoop = CeilDiv(nRound); + + // main loop + uint32_t kTileCount = CeilDiv(actualShape.k()); + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; kLoopIdx++) { + uint32_t l1ListIdNext = (l1ListId + 1 < STAGES) ? (l1ListId + 1) : 0; + uint32_t kActualNext{0}; + // preload next tile from GM to L1 + if (kLoopIdx < kTileCount - 1) { + uint32_t kLoopIdxNext = kLoopIdx + 1; + kActualNext = (kLoopIdxNext < kTileCount - 1) ? 
L1TileShape::K + : (actualShape.k() - kLoopIdxNext * L1TileShape::K); + + // Get L1 tensor for next stage + auto l1ATensor = l1ATensorList[l1ListIdNext]; + auto l1BTensor = l1BTensorList[l1ListIdNext]; + // Get GM tile for next stage + MatrixCoord gmTileAOffset{0, kLoopIdxNext * L1TileShape::K}; + MatrixCoord gmTileBOffset{kLoopIdxNext * L1TileShape::K, 0}; + auto gmTileA = gmWA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmWB[layoutB.GetOffset(gmTileBOffset)]; + + // load next matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActualNext)); + copyGmToL1A(l1ATensor, gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + + // load next matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActualNext, actualShape.n())); + copyGmToL1B(l1BTensor, gmTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } else if (hasNextBlock) { + uint32_t kLoopIdxNext = 0; + kActualNext = (kLoopIdxNext < kTileCount - 1) ? L1TileShape::K : (nextActualShape.k()); + + // Get L1 tensor for next stage + auto l1ATensor = l1ATensorList[l1ListIdNext]; + auto l1BTensor = l1BTensorList[l1ListIdNext]; + // Get GM tile for next stage + MatrixCoord gmTileAOffset{0, kLoopIdxNext * L1TileShape::K}; + MatrixCoord gmTileBOffset{kLoopIdxNext * L1TileShape::K, 0}; + auto gmTileA = gmNextWA[layoutNextA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmNextWB[layoutNextB.GetOffset(gmTileBOffset)]; + + // load next matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + auto layoutTileA = layoutNextA.GetTileLayout(MakeCoord(nextActualShape.m(), kActualNext)); + copyGmToL1A(l1ATensor, gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + + // load next matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + auto layoutTileB = layoutNextB.GetTileLayout(MakeCoord(kActualNext, nextActualShape.n())); + copyGmToL1B(l1BTensor, gmTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + + // Get L1 tensor for current stage + auto l1ATensor = l1ATensorList[l1ListId]; + auto l1BTensor = l1BTensorList[l1ListId]; + + // Get the loop nums on L0 + uint32_t kPartLoop = CeilDiv(kActual); + + for (int mPartIdx = 0; mPartIdx < mPartLoop; mPartIdx++) { + uint32_t mPartActual = + (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (mRound - mPartIdx * L0TileShape::M); + + for (int kPartIdx = 0; kPartIdx < kPartLoop; kPartIdx++) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) ? 
L0TileShape::K : (kActual - kPartIdx * L0TileShape::K); + + // Locate the current tile on L0A + auto l0ATile = l0ATensorList[l0AListId]; + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + // Locate the current tile of matrix A on L1 + MatrixCoord l1AOffset{mPartIdx * L0TileShape::M, kPartIdx * L0TileShape::K}; + auto l1ATile = l1ATensor[layoutAInL1.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[l1ListId]); + } + + // Load current tile from L1 to L0A + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, layoutAInL1); + + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[l1ListId]); + } + + for (int nPartIdx = 0; nPartIdx < nPartLoop; nPartIdx++) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? L0TileShape::N : (nRound - nPartIdx * L0TileShape::N); + + // Locate the current tile on L0B + auto l0BTile = l0BTensorList[l0BListId]; + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + // Locate the current tile of matrix B on L1 + MatrixCoord l1BOffset{kPartIdx * L0TileShape::K, nPartIdx * L0TileShape::N}; + auto l1BTile = l1BTensor[layoutBInL1.GetOffset(l1BOffset)]; + + // Wait for mmad finished + AscendC::WaitFlag(l0BEventList[l0BListId]); + // If the current tile is the first one on the k&n axis, wait for loading matrix B from GM to L1 + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[l1ListId]); + } + + // Load current tile from L1 to L0B + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, layoutBInL1); + + // If the current tile is the last one on the k&n axis, notify to load matrix B from GM to L1 + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[l1ListId]); + } + // Notify to do mmad + AscendC::SetFlag(EVENT_ID0); + + // Locate the current tile on L0C + MatrixCoord l0COffset{mPartIdx * L0TileShape::M, nPartIdx * L0TileShape::N}; + auto l0CTile = l0CTensor[layoutInL0C.GetOffset(l0COffset)]; + + // Compute the matrix multiplication on L0A and L0B and write the result to the accumulator + // Wait for loading L0B + AscendC::WaitFlag(EVENT_ID0); + + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = ((kLoopIdx == 0) && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if ((kLoopIdx == kTileCount - 1) && (mPartIdx == mPartLoop - 1) && + (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + // Perform calculation operations + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + // Notify to move the next L0B tile + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < STAGES) ? 
(l0AListId + 1) : 0; + } + } + l1ListId = l1ListIdNext; + kActual = kActualNext; + } + + // copy block out + LayoutC layoutBlock = layoutC.GetTileLayout(actualShape.GetCoordMN()); + + if (!isFirstKSlice) { // after splitting K, non-first slices enable atomic add + AscendC::SetAtomicAdd(); + } + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C); + AscendC::SetFlag(EVENT_ID0); + } else { + copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C, 0b11); + } + if (!isFirstKSlice) { + AscendC::SetAtomicNone(); + } + } + +protected: + // Multi-stage tensors list + AscendC::LocalTensor l1ATensorList[STAGES]; + AscendC::LocalTensor l1BTensorList[STAGES]; + AscendC::LocalTensor l0ATensorList[STAGES]; + AscendC::LocalTensor l0BTensorList[STAGES]; + AscendC::LocalTensor l0CTensor; + + // Multi-stage event id list + int32_t l1AEventList[STAGES]; + int32_t l1BEventList[STAGES]; + int32_t l0AEventList[STAGES]; + int32_t l0BEventList[STAGES]; + + // The id of current stage + uint32_t l1ListId{0}; + uint32_t l0AListId{0}; + uint32_t l0BListId{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +} // namespace Catlass::Gemm::Block +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_SLICE_K_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong_tla.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong_tla.hpp new file mode 100644 index 00000000..9dcc09ef --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_pingpong_tla.hpp @@ -0,0 +1,354 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
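The slice-K variant that ends above writes partial results per k slice: the first slice stores its block of C directly, and every later slice turns on atomic accumulation (SetAtomicAdd before the copy-out, SetAtomicNone after) so its partial product is added onto what is already in GM. The plain C++ reference below shows the equivalent arithmetic; kSlices and the row-major float buffers are assumptions of this sketch, not part of the patch.

#include <cstddef>
#include <vector>

void split_k_reference(const std::vector<float>& A, const std::vector<float>& B,
                       std::vector<float>& C, std::size_t M, std::size_t N, std::size_t K,
                       std::size_t kSlices)
{
    std::size_t sliceLen = (K + kSlices - 1) / kSlices;
    for (std::size_t slice = 0; slice < kSlices; ++slice) {
        bool isFirstKSlice = (slice == 0);
        std::size_t k0 = slice * sliceLen;
        std::size_t kEnd = (k0 + sliceLen < K) ? (k0 + sliceLen) : K;
        for (std::size_t m = 0; m < M; ++m) {
            for (std::size_t n = 0; n < N; ++n) {
                float partial = 0.0f;
                for (std::size_t k = k0; k < kEnd; ++k) {
                    partial += A[m * K + k] * B[k * N + n];
                }
                // first slice: plain store; later slices: add onto the result already in C
                C[m * N + n] = isFirstKSlice ? partial : (C[m * N + n] + partial);
            }
        }
    }
}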
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_TLA_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_TLA_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" +#include "tla/layout.hpp" +#include "tla/tensor.hpp" + +namespace Catlass::Gemm::Block { + +template +struct BlockMmadTla, L1TileShape_, L0TileShape_, ElementA_, ElementB_, ElementC_, + ElementBias_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2Pingpong; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = ElementA_; + using LayoutA = typename TileCopy_::LayoutA; + using ElementB = ElementB_; + using LayoutB = typename TileCopy_::LayoutB; + using ElementC = ElementC_; + using LayoutC = typename TileCopy_::LayoutC; + + using TileMmad = TileMmad_; + + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + + using ElementAccumulator = typename TileCopy_::ElementAccumulator; + + using LayoutTagL1A = typename TileCopy_::LayoutTagL1A; + using LayoutTagL1B = typename TileCopy_::LayoutTagL1B; + using LayoutTagL0A = typename TileCopy_::LayoutTagL0A; + using LayoutTagL0B = typename TileCopy_::LayoutTagL0B; + + using L1AAlignHelper = typename TileCopy_::L1AAlignHelper; + using L1BAlignHelper = typename TileCopy_::L1BAlignHelper; + + static_assert(tla::is_tuple::value && tla::is_static::value, + "L1TileShape must be tla::tuple and static!"); + static_assert(tla::is_tuple::value && tla::is_static::value, + "L0TileShape must be tla::tuple and static!"); + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1_TILE_M = tla::get<0>(L1TileShape{}); + static constexpr uint32_t L1_TILE_N = tla::get<1>(L1TileShape{}); + static constexpr uint32_t L1_TILE_K = tla::get<2>(L1TileShape{}); + static constexpr uint32_t L0_TILE_M = tla::get<0>(L0TileShape{}); + static constexpr uint32_t L0_TILE_N = tla::get<1>(L0TileShape{}); + static constexpr uint32_t L0_TILE_K = tla::get<2>(L0TileShape{}); + + // L1 tile size + static constexpr uint32_t L1A_TILE_SIZE = L1_TILE_M * L1_TILE_K * sizeof(ElementA); + static constexpr uint32_t L1B_TILE_SIZE = L1_TILE_N * L1_TILE_K * sizeof(ElementB); + // L0 tile size + static constexpr uint32_t L0A_TILE_SIZE = L0_TILE_M * L0_TILE_K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0_TILE_K * L0_TILE_N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L1_TILE_M * L1_TILE_N * sizeof(ElementAccumulator); + + // Check LayoutC + static_assert(tla::detail::isRowMajor::value, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert((L1A_TILE_SIZE + L1B_TILE_SIZE) * STAGES <= ArchTag::L1_SIZE, "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static_assert(L0A_TILE_SIZE * STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert(L0B_TILE_SIZE * STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!"); + + static_assert(L1_TILE_M == L0_TILE_M && L1_TILE_N == L0_TILE_N, + "The 
situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + static_assert(L0_TILE_K <= L1_TILE_K, "L0TileShape::K cannot exceed L1TileShape::K"); + + static constexpr auto L1A_LAYOUT = tla::MakeLayout(Int{}, Int{}); + static constexpr auto L1B_LAYOUT = tla::MakeLayout(Int{}, Int{}); + + /// Construct + CATLASS_DEVICE + BlockMmadTla(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * STAGES; + // Init buffers + for (uint32_t i = 0; i < STAGES; i++) { + // Assign L1/L0A/L0B space for each stages + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_TILE_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_TILE_SIZE * i); + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); + + // Assign event ID for each stages + l1AEventList[i] = i; + l1BEventList[i] = i + STAGES; + l0AEventList[i] = i; + l0BEventList[i] = i + STAGES; + + // The event id that needs to be set before the loop + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + AscendC::SetFlag(l0AEventList[i]); + AscendC::SetFlag(l0BEventList[i]); + } + l0CTensor = resource.l0CBuf.template GetBufferByByte(0); + AscendC::SetFlag(EVENT_ID0); + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmadTla() + { + for (uint32_t i = 0; i < STAGES; i++) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + AscendC::WaitFlag(l0AEventList[i]); + AscendC::WaitFlag(l0BEventList[i]); + } + AscendC::WaitFlag(EVENT_ID0); + } + + /// Perform a block-scoped matrix multiply-accumulate + template + CATLASS_DEVICE void operator()(TensorA &tensorA, TensorB &tensorB, TensorC &tensorC, GemmCoord const &actualShape) + { + using CopyGmToL1A = typename TileCopy_::template CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::template CopyGmToL1B; + using CopyL0CToGm = typename TileCopy_::template CopyL0CToGm; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL0CToGm copyL0CToGm; + + uint32_t mBlockActual = actualShape.m(); + uint32_t kBlockActual = actualShape.k(); + uint32_t nBlockActual = actualShape.n(); + + uint32_t mRound = RoundUp(mBlockActual); + uint32_t nRound = RoundUp(nBlockActual); + + auto layoutInL0C = tla::MakeLayoutL0C(mRound, nRound); + auto tensorL0C = tla::MakeTensor(l0CTensor, layoutInL0C, Arch::PositionL0C{}); + + uint32_t kActual = min(kBlockActual, L1_TILE_K); + // load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto tensorL1A = tla::MakeTensor(l1ATensorList[l1ListId], L1A_LAYOUT, Arch::PositionL1{}); + auto tensorTileA = GetTile(tensorA, tla::MakeCoord(0, 0), tla::MakeShape(mBlockActual, kActual)); + copyGmToL1A(tensorL1A, tensorTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + + // load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto tensorL1B = tla::MakeTensor(l1BTensorList[l1ListId], L1B_LAYOUT, Arch::PositionL1{}); + auto tensorTileB = GetTile(tensorB, tla::MakeCoord(0, 0), tla::MakeShape(kActual, nBlockActual)); + copyGmToL1B(tensorL1B, tensorTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::WaitFlag(EVENT_ID0); + } + + uint32_t mPartLoop = CeilDiv(mRound); + uint32_t nPartLoop = CeilDiv(nRound); + + // main loop + uint32_t kTileCount = 
CeilDiv(kBlockActual); + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; kLoopIdx++) { + uint32_t l1ListIdNext = (l1ListId + 1 < STAGES) ? (l1ListId + 1) : 0; + uint32_t kActualNext{0}; + // preload next tile from GM to L1 + if (kLoopIdx < kTileCount - 1) { + uint32_t kLoopIdxNext = kLoopIdx + 1; + kActualNext = (kLoopIdxNext < kTileCount - 1) ? L1_TILE_K : (kBlockActual - kLoopIdxNext * L1_TILE_K); + + // Get L1 tensor for next stage + auto l1ATensor = l1ATensorList[l1ListIdNext]; + auto l1BTensor = l1BTensorList[l1ListIdNext]; + auto tensorL1A = tla::MakeTensor(l1ATensor, L1A_LAYOUT, Arch::PositionL1{}); + auto tensorL1B = tla::MakeTensor(l1BTensor, L1B_LAYOUT, Arch::PositionL1{}); + // Get GM tile for next stage + auto tensorTileA = GetTile(tensorA, tla::MakeCoord(0, kLoopIdxNext * L1_TILE_K), + tla::MakeShape(mBlockActual, kActualNext)); + auto tensorTileB = GetTile(tensorB, tla::MakeCoord(kLoopIdxNext * L1_TILE_K, 0), + tla::MakeShape(kActualNext, nBlockActual)); + + // load next matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + copyGmToL1A(tensorL1A, tensorTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + + // load next matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + copyGmToL1B(tensorL1B, tensorTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + + // Get L1 tensor for current stage + auto l1ATensor = l1ATensorList[l1ListId]; + auto l1BTensor = l1BTensorList[l1ListId]; + tensorL1A = tla::MakeTensor(l1ATensor, L1A_LAYOUT, Arch::PositionL1{}); + tensorL1B = tla::MakeTensor(l1BTensor, L1B_LAYOUT, Arch::PositionL1{}); + // Get the loop nums on L0 + uint32_t kPartLoop = CeilDiv(kActual); + + for (int mPartIdx = 0; mPartIdx < mPartLoop; mPartIdx++) { + uint32_t mPartActual = (mPartIdx < mPartLoop - 1) ? L0_TILE_M : (mRound - mPartIdx * L0_TILE_M); + + for (int kPartIdx = 0; kPartIdx < kPartLoop; kPartIdx++) { + uint32_t kPartActual = (kPartIdx < kPartLoop - 1) ? L0_TILE_K : (kActual - kPartIdx * L0_TILE_K); + + // Locate the current tile on L0A + auto l0ATile = l0ATensorList[l0AListId]; + auto layoutAInL0 = tla::MakeLayout(mPartActual, kPartActual); + auto tensorL0A = tla::MakeTensor(l0ATile, layoutAInL0, Arch::PositionL0A{}); + // Locate the current tile of matrix A on L1 + auto tensorTileL1A = GetTile(tensorL1A, tla::MakeCoord(mPartIdx * L0_TILE_M, kPartIdx * L0_TILE_K), + tla::MakeShape(mPartActual, kPartActual)); + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[l1ListId]); + } + + // Load current tile from L1 to L0A + copyL1ToL0A(tensorL0A, tensorTileL1A); + + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[l1ListId]); + } + + for (int nPartIdx = 0; nPartIdx < nPartLoop; nPartIdx++) { + uint32_t nPartActual = (nPartIdx < nPartLoop - 1) ? 
L0_TILE_N : (nRound - nPartIdx * L0_TILE_N); + + // Locate the current tile on L0B + auto l0BTile = l0BTensorList[l0BListId]; + auto layoutBInL0 = tla::MakeLayout(kPartActual, nPartActual); + auto tensorL0B = tla::MakeTensor(l0BTile, layoutBInL0, Arch::PositionL0B{}); + // Locate the current tile of matrix B on L1 + auto tensorTileL1B = + GetTile(tensorL1B, tla::MakeCoord(kPartIdx * L0_TILE_K, nPartIdx * L0_TILE_N), + tla::MakeShape(kPartActual, nPartActual)); + + // Wait for mmad finished + AscendC::WaitFlag(l0BEventList[l0BListId]); + // If the current tile is the first one on the k&n axis, wait for loading matrix B from GM to L1 + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[l1ListId]); + } + + // Load current tile from L1 to L0B + copyL1ToL0B(tensorL0B, tensorTileL1B); + + // If the current tile is the last one on the k&n axis, notify to load matrix B from GM to L1 + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[l1ListId]); + } + // Notify to do mmad + AscendC::SetFlag(EVENT_ID0); + + // Locate the current tile on L0C + auto tensorTileL0C = + GetTile(tensorL0C, tla::MakeCoord(mPartIdx * L0_TILE_M, nPartIdx * L0_TILE_N), + tla::MakeShape(mPartActual, nPartActual)); + + // Compute the matrix multiplication on L0A and L0B and write the result to the accumulator + // Wait for loading L0B + AscendC::WaitFlag(EVENT_ID0); + + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = ((kLoopIdx == 0) && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if ((kLoopIdx == kTileCount - 1) && (mPartIdx == mPartLoop - 1) && + (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + // Perform calculation operations + tileMmad(tensorTileL0C, tensorL0A, tensorL0B, mPartActual, nPartActual, kPartActual, initC, + unitFlag); + + // Notify to move the next L0B tile + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < STAGES) ? 
(l0AListId + 1) : 0; + } + } + l1ListId = l1ListIdNext; + kActual = kActualNext; + } + + // copy block out + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + copyL0CToGm(tensorC, tensorL0C); + AscendC::SetFlag(EVENT_ID0); + } else { + copyL0CToGm(tensorC, tensorL0C, 0b11); + } + } + +protected: + // Multi-stage tensors list + AscendC::LocalTensor l1ATensorList[STAGES]; + AscendC::LocalTensor l1BTensorList[STAGES]; + AscendC::LocalTensor l0ATensorList[STAGES]; + AscendC::LocalTensor l0BTensorList[STAGES]; + AscendC::LocalTensor l0CTensor; + + // Multi-stage event id list + int32_t l1AEventList[STAGES]; + int32_t l1BEventList[STAGES]; + int32_t l0AEventList[STAGES]; + int32_t l0BEventList[STAGES]; + + // The id of current stage + uint32_t l1ListId{0}; + uint32_t l0AListId{0}; + uint32_t l0BListId{0}; + + TileMmad tileMmad; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; +}; + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PINGPONG_TLA_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_preload.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_preload.hpp new file mode 100644 index 00000000..04250cb1 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_preload.hpp @@ -0,0 +1,376 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
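The TLA variant that ends above moves the tile shapes into static tla::tuple types, so the L1/L0 budget checks become compile-time assertions over constexpr sizes. The same idea in plain standard C++ follows, with example tile sizes, stage count, L1 capacity, and element width standing in for the ArchTag and element parameters of the patch.

#include <cstdint>

template <uint32_t M, uint32_t N, uint32_t K>
struct TileShape {
    static constexpr uint32_t m = M;
    static constexpr uint32_t n = N;
    static constexpr uint32_t k = K;
};

using L1Tile = TileShape<128, 256, 256>;       // example L1 tile shape
constexpr uint32_t kStages = 2;                // ping-pong
constexpr uint32_t kL1SizeBytes = 512 * 1024;  // example stand-in for ArchTag::L1_SIZE
constexpr uint32_t kElemBytes = 2;             // e.g. a 16-bit element type

// Mirrors the "L1TileShape exceeding the L1 space!" check in spirit:
// A tile (m x k) plus B tile (k x n), double buffered, must fit in L1.
static_assert((L1Tile::m * L1Tile::k + L1Tile::k * L1Tile::n) * kElemBytes * kStages <= kL1SizeBytes,
              "L1 tile shapes for A and B do not fit into L1 with double buffering");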
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +namespace Catlass::Gemm::Block { + +template +struct BlockMmad, L1TileShape_, L0TileShape_, AType_, BType_, + CType_, BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2Preload; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert((L1A_SIZE * STAGES + L1B_SIZE * STAGES) <= ArchTag::L1_SIZE, "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L0TileShape::M * L0TileShape::N * sizeof(ElementAccumulator); + static_assert((L0A_TILE_SIZE * STAGES) <= L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert((L0B_TILE_SIZE * STAGES) <= L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE <= L0C_SIZE, "L0TileShape exceeding the L0C space!"); + + static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + 
static_assert(L0TileShape::K <= L1TileShape::K, "L0TileShape::K cannot exceed L1TileShape::K"); + + /// Construct + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_SIZE * STAGES; + // Init buffers + for (uint32_t i = 0; i < STAGES; i++) { + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_SIZE * i); + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + + l1AEventList[i] = i; + l1BEventList[i] = i + STAGES; + l0AEventList[i] = i; + l0BEventList[i] = i + STAGES; + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + AscendC::SetFlag(l0AEventList[i]); + AscendC::SetFlag(l0BEventList[i]); + } + l0CTensor = resource.l0CBuf.template GetBufferByByte(0); + AscendC::SetFlag(EVENT_ID0); + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() + { + for (uint32_t i = 0; i < STAGES; i++) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + AscendC::WaitFlag(l0AEventList[i]); + AscendC::WaitFlag(l0BEventList[i]); + } + AscendC::WaitFlag(EVENT_ID0); + } + + /// Perform a block-scoped matrix multiply-accumulate + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmBlockA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmBlockB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutC, + AscendC::GlobalTensor const &gmNextBlockA, + AscendC::GlobalTensor const &gmNextBlockB, GemmCoord const &actualShape, + GemmCoord const &actualShapeNext, bool isFirstBlock, bool hasNextBlock) + { + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + auto layoutAInL1 = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + auto layoutBInL1 = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(MakeCoord(mRound, nRound)); + + uint32_t kTileCount = CeilDiv(actualShape.k()); + uint32_t kTileCountNext = CeilDiv(actualShapeNext.k()); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::WaitFlag(EVENT_ID0); + } + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx(); + } + uint32_t firstTileIdx = startTileIdx % kTileCount; + uint32_t lastTileIdx = (startTileIdx + kTileCount - 1) % kTileCount; + uint32_t kActual = + (firstTileIdx < kTileCount - 1) ? 
L1TileShape::K : (actualShape.k() - firstTileIdx * L1TileShape::K); + uint32_t firstTileIdxNext = startTileIdx % kTileCountNext; + + uint32_t mPartLoop = CeilDiv(mRound); + uint32_t nPartLoop = CeilDiv(nRound); + + // k loop + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; kLoopIdx++) { + uint32_t shuffleKIdx = (startTileIdx + kLoopIdx) % kTileCount; + // Load first matrix A tile in total kernel loop from GM to L1 + if (shuffleKIdx == firstTileIdx && isFirstBlock) { + MatrixCoord gmTileAOffset{0, shuffleKIdx * L1TileShape::K}; + MatrixCoord gmTileBOffset{shuffleKIdx * L1TileShape::K, 0}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + // Load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1ListId], gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + + // Load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1ListId], gmTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + } + + uint32_t l1ListIdNext = (l1ListId + 1 < STAGES) ? (l1ListId + 1) : 0; + uint32_t kActualNext{0}; + + // preload next tile from GM to L1 + if (shuffleKIdx != lastTileIdx) { + uint32_t shuffleKIdxNext = (startTileIdx + kLoopIdx + 1) % kTileCount; + // Get L1 tensor for next stage + auto l1ATensor = l1ATensorList[l1ListIdNext]; + auto l1BTensor = l1BTensorList[l1ListIdNext]; + // Get GM tensor for next stage + kActualNext = (shuffleKIdxNext < kTileCount - 1) ? L1TileShape::K + : (actualShape.k() - shuffleKIdxNext * L1TileShape::K); + MatrixCoord gmTileAOffset{0, shuffleKIdxNext * L1TileShape::K}; + MatrixCoord gmTileBOffset{shuffleKIdxNext * L1TileShape::K, 0}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + + // load next matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActualNext)); + copyGmToL1A(l1ATensor, gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + + // load next matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActualNext, actualShape.n())); + copyGmToL1B(l1BTensor, gmTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + if (shuffleKIdx == lastTileIdx && hasNextBlock) { + // Get L1 tensor for next stage + auto l1ATensor = l1ATensorList[l1ListIdNext]; + auto l1BTensor = l1BTensorList[l1ListIdNext]; + // Get GM tensor for next stage + kActualNext = (firstTileIdxNext < kTileCountNext - 1) + ? 
L1TileShape::K + : (actualShapeNext.k() - firstTileIdxNext * L1TileShape::K); + MatrixCoord gmTileAOffset{0, firstTileIdxNext * L1TileShape::K}; + MatrixCoord gmTileBOffset{firstTileIdxNext * L1TileShape::K, 0}; + auto gmTileA = gmNextBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmNextBlockB[layoutB.GetOffset(gmTileBOffset)]; + // load next matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShapeNext.m(), kActualNext)); + copyGmToL1A(l1ATensor, gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + + // load next matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActualNext, actualShapeNext.n())); + copyGmToL1B(l1BTensor, gmTileB, layoutBInL1, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + + // Get L1 tensor for current stage + auto l1ATensor = l1ATensorList[l1ListId]; + auto l1BTensor = l1BTensorList[l1ListId]; + + // Get the loop nums on L0 + uint32_t kPartLoop = CeilDiv(kActual); + + uint32_t l0ABufId = 0; + uint32_t l0BBufId = 0; + + for (int mPartIdx = 0; mPartIdx < mPartLoop; mPartIdx++) { + uint32_t mPartActual = + (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (mRound - mPartIdx * L0TileShape::M); + + for (int kPartIdx = 0; kPartIdx < kPartLoop; kPartIdx++) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) ? L0TileShape::K : (kActual - kPartIdx * L0TileShape::K); + + // Locate the current tile on L0A + auto l0ATile = l0ATensorList[l0ABufId]; + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + // Locate the current tile of matrix A on L1 + MatrixCoord l1AOffset{mPartIdx * L0TileShape::M, kPartIdx * L0TileShape::K}; + auto l1ATile = l1ATensor[layoutAInL1.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0ABufId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[l1ListId]); + } + + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, layoutAInL1); + + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[l1ListId]); + } + + for (int nPartIdx = 0; nPartIdx < nPartLoop; nPartIdx++) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? 
L0TileShape::N : (nRound - nPartIdx * L0TileShape::N); + + // Locate the current tile on L0B + auto l0BTile = l0BTensorList[l0BBufId]; + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + // Locate the current tile of matrix B on L1 + MatrixCoord l1BOffset{kPartIdx * L0TileShape::K, nPartIdx * L0TileShape::N}; + auto l1BTile = l1BTensor[layoutBInL1.GetOffset(l1BOffset)]; + + // Wait for mmad finished + AscendC::WaitFlag(l0BEventList[l0BBufId]); + // If the current tile is the first one on the k&n axis, wait for loading matrix B from GM to L1 + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[l1ListId]); + } + + // Load current tile from L1 to L0B + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, layoutBInL1); + + // If the current tile is the last one on the k&n axis, notify to load matrix B from GM to L1 + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[l1ListId]); + } + // Notify to do mmad + AscendC::SetFlag(EVENT_ID0); + + // Locate the current tile on L0C + MatrixCoord l0COffset{mPartIdx * L0TileShape::M, nPartIdx * L0TileShape::N}; + auto l0CTile = l0CTensor[layoutInL0C.GetOffset(l0COffset)]; + + // Compute the matrix multiplication on L0A and L0B and write the result to the accumulator + // Wait for loading L0B + AscendC::WaitFlag(EVENT_ID0); + + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = ((kLoopIdx == 0) && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if ((kLoopIdx == kTileCount - 1) && (mPartIdx == mPartLoop - 1) && + (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + // Perform calculation operations + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + // Notify to move the next L0B tile + AscendC::SetFlag(l0BEventList[l0BBufId]); + + l0BBufId = (l0BBufId + 1 < STAGES) ? (l0BBufId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0ABufId]); + l0ABufId = (l0ABufId + 1 < STAGES) ? 
(l0ABufId + 1) : 0; + } + } + + l1ListId = l1ListIdNext; + kActual = kActualNext; + } + + // copy block out + LayoutC layoutBlock = layoutC.GetTileLayout(actualShape.GetCoordMN()); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + copyL0CToGm(gmBlockC, l0CTensor, layoutBlock, layoutInL0C); + AscendC::SetFlag(EVENT_ID0); + } else { + copyL0CToGm(gmBlockC, l0CTensor, layoutBlock, layoutInL0C, 0b11); + } + } + +protected: + /// Data members + AscendC::LocalTensor l1ATensorList[STAGES]; + AscendC::LocalTensor l1BTensorList[STAGES]; + AscendC::LocalTensor l0ATensorList[STAGES]; + AscendC::LocalTensor l0BTensorList[STAGES]; + AscendC::LocalTensor l0CTensor; + + int32_t l1AEventList[STAGES]; + int32_t l1BEventList[STAGES]; + int32_t l0AEventList[STAGES]; + int32_t l0BEventList[STAGES]; + + uint32_t l1ListId{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_preload_async.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_preload_async.hpp new file mode 100644 index 00000000..4ed4ecaa --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_preload_async.hpp @@ -0,0 +1,404 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
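Editorial note, not part of the patch: block_mmad_preload_async.hpp, introduced here, records each issued GM-to-L1 copy in a small ring of PRELOAD_STAGES parameter slots and only runs the matching MMAD once the ring is full; SynchronizeBlock() drains whatever is still outstanding. A host-side analogue of that deferred-execution ring, with the copy and MMAD reduced to prints and PRELOAD_STAGES assumed to be 2:

#include <array>
#include <cstdint>
#include <cstdio>

constexpr std::uint32_t kPreloadStages = 2;  // assumption

struct TileParams { std::uint32_t tileIdx; };

std::array<TileParams, kPreloadStages> ring{};
std::uint32_t head = 0;      // oldest outstanding issue
std::uint32_t inFlight = 0;  // issues not yet computed

void compute(const TileParams &p) { std::printf("compute tile %u\n", p.tileIdx); }

void issue(std::uint32_t tileIdx) {
    if (inFlight == kPreloadStages) {  // ring full: retire the oldest first
        compute(ring[head]);
    }
    std::uint32_t slot = (head + inFlight) % kPreloadStages;
    ring[slot] = TileParams{tileIdx};
    if (inFlight < kPreloadStages) {
        ++inFlight;
    } else {
        head = (head + 1) % kPreloadStages;
    }
}

void synchronize() {  // analogue of SynchronizeBlock()
    while (inFlight > 0) {
        compute(ring[head]);
        head = (head + 1) % kPreloadStages;
        --inFlight;
    }
}

int main() {
    for (std::uint32_t t = 0; t < 5; ++t) issue(t);  // tiles retire in issue order: 0,1,2,3,4
    synchronize();
    return 0;
}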
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +namespace Catlass::Gemm::Block { + +template +struct BlockMmad, + L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2PreloadAsync; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES; + static constexpr uint32_t L1_STAGES = DispatchPolicy::L1_STAGES; + static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES; + static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES; + static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + + // L1 tile size + static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + // L0 tile size + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator); + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert((L1A_TILE_SIZE + L1B_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE, + "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!"); + + static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == 
L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + static_assert(L0TileShape::K <= L1TileShape::K, "L0TileShape::K cannot exceed L1TileShape::K"); + + static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + InitL1(resource, l1BufAddrStart); + InitL0A(resource); + InitL0B(resource); + InitL0C(resource); + } + + CATLASS_DEVICE + ~BlockMmad() + { + SynchronizeBlock(); + for (uint32_t i = 0; i < L1_STAGES; ++i) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + } + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + AscendC::WaitFlag(l0AEventList[i]); + } + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + AscendC::WaitFlag(l0BEventList[i]); + } + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + AscendC::WaitFlag(l0CEventList[i]); + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmBlockA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmBlockB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutC, + GemmCoord const &actualShape, Callback &&callback = Callback{}) + { + uint32_t kTileCount = CeilDiv(actualShape.k()); + + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx() % kTileCount; + } + + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) { + uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ? (startTileIdx + kLoopIdx) + : (startTileIdx + kLoopIdx - kTileCount); + + uint32_t kActual = + (kTileIdx < kTileCount - 1) ? L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K); + + // Emission load instruction from GM to L1 + MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K}; + MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + // Load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + // Load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1ListId], gmTileB, L1B_LAYOUT, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + + // If the number of preload instructions reaches the upper limit, perform an mmad calculation on L1 tile + if (preloadCount == PRELOAD_STAGES) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + } + + // Store the current load status + uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) + ? 
(l1TileMmadParamsId + preloadCount) + : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES); + auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId]; + l1TileMmadParams.l1ListId = l1ListId; + l1TileMmadParams.mRound = mRound; + l1TileMmadParams.nRound = nRound; + l1TileMmadParams.kActual = kActual; + l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0); + l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1); + if (kLoopIdx == kTileCount - 1) { + l1TileMmadParams.gmBlockC = gmBlockC; + l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN()); + l1TileMmadParams.callback = callback; + } + + if (preloadCount < PRELOAD_STAGES) { + ++preloadCount; + } else { + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + } + l1ListId = (l1ListId + 1 < L1_STAGES) ? (l1ListId + 1) : 0; + } + } + + CATLASS_DEVICE + void SynchronizeBlock() + { + while (preloadCount > 0) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + --preloadCount; + } + } + +private: + struct L1TileMmadParams { + uint32_t l1ListId; + uint32_t mRound; + uint32_t nRound; + uint32_t kActual; + bool isKLoopFirst; + bool isKLoopLast; + AscendC::GlobalTensor gmBlockC; + LayoutC layoutCInGm; + Callback callback; + + CATLASS_DEVICE + L1TileMmadParams() = default; + }; + + CATLASS_DEVICE + void InitL1(Arch::Resource &resource, uint32_t l1BufAddrStart) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1_STAGES; + for (uint32_t i = 0; i < L1_STAGES; ++i) { + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_TILE_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_TILE_SIZE * i); + l1AEventList[i] = i; + l1BEventList[i] = i + L1_STAGES; + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0A(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); + l0AEventList[i] = i; + AscendC::SetFlag(l0AEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0B(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); + l0BEventList[i] = i + L0A_STAGES; + AscendC::SetFlag(l0BEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0C(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte(L0C_TILE_SIZE * i); + l0CEventList[i] = i; + AscendC::SetFlag(l0CEventList[i]); + } + } + + CATLASS_DEVICE + void L1TileMmad(L1TileMmadParams const ¶ms) + { + uint32_t mPartLoop = CeilDiv(params.mRound); + uint32_t nPartLoop = CeilDiv(params.nRound); + uint32_t kPartLoop = CeilDiv(params.kActual); + auto &l1ATensor = l1ATensorList[params.l1ListId]; + auto &l1BTensor = l1BTensorList[params.l1ListId]; + + auto &l0CTensor = l0CTensorList[l0CListId]; + LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound)); + + if constexpr (!ENABLE_UNIT_FLAG) { + if (params.isKLoopFirst) { + AscendC::WaitFlag(l0CEventList[l0CListId]); + } + } + + for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) { + uint32_t mPartActual = + (mPartIdx < mPartLoop - 1) ? 
L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M); + + for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) ? L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K); + + auto &l0ATile = l0ATensorList[l0AListId]; + auto layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK(); + auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[params.l1ListId]); + } + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT); + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[params.l1ListId]); + } + + for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N); + + auto &l0BTile = l0BTensorList[l0BListId]; + auto layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN(); + auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)]; + + AscendC::WaitFlag(l0BEventList[l0BListId]); + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[params.l1ListId]); + } + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT); + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[params.l1ListId]); + } + + AscendC::SetFlag(EVENT_ID0); + + auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN(); + auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)]; + + AscendC::WaitFlag(EVENT_ID0); + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = (params.isKLoopFirst && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if (params.isKLoopLast && (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && + (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0; + } + } + + if (params.isKLoopLast) { + auto layoutCInGm = params.layoutCInGm; + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(l0CEventList[l0CListId]); + AscendC::WaitFlag(l0CEventList[l0CListId]); + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0); + AscendC::SetFlag(l0CEventList[l0CListId]); + } else { + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11); + } + l0CListId = (l0CListId + 1 < L0C_STAGES) ? 
(l0CListId + 1) : 0; + + if (params.callback) { + params.callback(); + } + } + } + + AscendC::LocalTensor l1ATensorList[L1_STAGES]; + AscendC::LocalTensor l1BTensorList[L1_STAGES]; + int32_t l1AEventList[L1_STAGES]; + int32_t l1BEventList[L1_STAGES]; + uint32_t l1ListId{0}; + + AscendC::LocalTensor l0ATensorList[L0A_STAGES]; + int32_t l0AEventList[L0A_STAGES]; + uint32_t l0AListId{0}; + + AscendC::LocalTensor l0BTensorList[L0B_STAGES]; + int32_t l0BEventList[L0B_STAGES]; + uint32_t l0BListId{0}; + + AscendC::LocalTensor l0CTensorList[L0C_STAGES_]; + int32_t l0CEventList[L0C_STAGES_]; + uint32_t l0CListId{0}; + + L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES]; + uint32_t l1TileMmadParamsId{0}; + uint32_t preloadCount{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_preload_async_with_callback.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_preload_async_with_callback.hpp new file mode 100644 index 00000000..9f0ce5ae --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_preload_async_with_callback.hpp @@ -0,0 +1,407 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
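Editorial note, not part of the patch: the *_with_callback variant introduced here differs from the async one mainly in taking two hooks that are invoked immediately before and immediately after the L0C-to-GM copy-out (the fixpipe step). A host-side analogue showing just that call ordering; the hook bodies are hypothetical.

#include <cstdio>
#include <functional>

void copyOutBlock(const std::function<void()> &beforeFixpipe,
                  const std::function<void()> &afterFixpipe) {
    beforeFixpipe();                             // e.g. signal that input buffers may be reused
    std::printf("fixpipe: copy L0C block to GM\n");
    afterFixpipe();                              // e.g. mark the output tile as ready
}

int main() {
    copyOutBlock([] { std::printf("before fixpipe\n"); },
                 [] { std::printf("after fixpipe\n"); });
    return 0;
}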
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +namespace Catlass::Gemm::Block { + +template +struct BlockMmad, + L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2PreloadAsyncWithCallback; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES; + static constexpr uint32_t L1_STAGES = DispatchPolicy::L1_STAGES; + static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES; + static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES; + static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + + // L1 tile size + static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + // L0 tile size + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator); + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert((L1A_TILE_SIZE + L1B_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE, + "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!"); + + static_assert(L1TileShape::M == 
L0TileShape::M && L1TileShape::N == L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + static_assert(L0TileShape::K <= L1TileShape::K, "L0TileShape::K cannot exceed L1TileShape::K"); + + static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + InitL1(resource, l1BufAddrStart); + InitL0A(resource); + InitL0B(resource); + InitL0C(resource); + } + + CATLASS_DEVICE + ~BlockMmad() + { + SynchronizeBlock(); + for (uint32_t i = 0; i < L1_STAGES; ++i) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + } + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + AscendC::WaitFlag(l0AEventList[i]); + } + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + AscendC::WaitFlag(l0BEventList[i]); + } + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + AscendC::WaitFlag(l0CEventList[i]); + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmBlockA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmBlockB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutC, + GemmCoord const &actualShape, Callback const &callbackBeforeFixpipe, + Callback const &callbackAfterFixpipe) + { + uint32_t kTileCount = CeilDiv(actualShape.k()); + + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx() % kTileCount; + } + + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) { + uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ? (startTileIdx + kLoopIdx) + : (startTileIdx + kLoopIdx - kTileCount); + + uint32_t kActual = + (kTileIdx < kTileCount - 1) ? L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K); + + // Emission load instruction from GM to L1 + MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K}; + MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + // Load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + // Load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1ListId], gmTileB, L1B_LAYOUT, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + + // If the number of preload instructions reaches the upper limit, perform an mmad calculation on L1 tile + if (preloadCount == PRELOAD_STAGES) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + } + + // Store the current load status + uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) + ? 
(l1TileMmadParamsId + preloadCount) + : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES); + auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId]; + l1TileMmadParams.l1ListId = l1ListId; + l1TileMmadParams.mRound = mRound; + l1TileMmadParams.nRound = nRound; + l1TileMmadParams.kActual = kActual; + l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0); + l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1); + if (kLoopIdx == kTileCount - 1) { + l1TileMmadParams.gmBlockC = gmBlockC; + l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN()); + l1TileMmadParams.callbackBeforeFixpipe = callbackBeforeFixpipe; + l1TileMmadParams.callbackAfterFixpipe = callbackAfterFixpipe; + } + + if (preloadCount < PRELOAD_STAGES) { + ++preloadCount; + } else { + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + } + l1ListId = (l1ListId + 1 < L1_STAGES) ? (l1ListId + 1) : 0; + } + } + + CATLASS_DEVICE + void SynchronizeBlock() + { + while (preloadCount > 0) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + --preloadCount; + } + } + +private: + struct L1TileMmadParams { + uint32_t l1ListId; + uint32_t mRound; + uint32_t nRound; + uint32_t kActual; + bool isKLoopFirst; + bool isKLoopLast; + AscendC::GlobalTensor gmBlockC; + LayoutC layoutCInGm; + Callback callbackBeforeFixpipe; + Callback callbackAfterFixpipe; + + CATLASS_DEVICE + L1TileMmadParams() = default; + }; + + CATLASS_DEVICE + void InitL1(Arch::Resource &resource, uint32_t l1BufAddrStart) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1_STAGES; + for (uint32_t i = 0; i < L1_STAGES; ++i) { + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_TILE_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_TILE_SIZE * i); + l1AEventList[i] = i; + l1BEventList[i] = i + L1_STAGES; + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0A(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); + l0AEventList[i] = i; + AscendC::SetFlag(l0AEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0B(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); + l0BEventList[i] = i + L0A_STAGES; + AscendC::SetFlag(l0BEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0C(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte(L0C_TILE_SIZE * i); + l0CEventList[i] = i; + AscendC::SetFlag(l0CEventList[i]); + } + } + + CATLASS_DEVICE + void L1TileMmad(L1TileMmadParams const ¶ms) + { + uint32_t mPartLoop = CeilDiv(params.mRound); + uint32_t nPartLoop = CeilDiv(params.nRound); + uint32_t kPartLoop = CeilDiv(params.kActual); + auto &l1ATensor = l1ATensorList[params.l1ListId]; + auto &l1BTensor = l1BTensorList[params.l1ListId]; + + auto &l0CTensor = l0CTensorList[l0CListId]; + LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound)); + + if constexpr (!ENABLE_UNIT_FLAG) { + if (params.isKLoopFirst) { + AscendC::WaitFlag(l0CEventList[l0CListId]); + } + } + + for (uint32_t mPartIdx = 0; mPartIdx < 
mPartLoop; ++mPartIdx) { + uint32_t mPartActual = + (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M); + + for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) ? L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K); + + auto &l0ATile = l0ATensorList[l0AListId]; + auto layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK(); + auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[params.l1ListId]); + } + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT); + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[params.l1ListId]); + } + + for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N); + + auto &l0BTile = l0BTensorList[l0BListId]; + auto layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN(); + auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)]; + + AscendC::WaitFlag(l0BEventList[l0BListId]); + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[params.l1ListId]); + } + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT); + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[params.l1ListId]); + } + + AscendC::SetFlag(EVENT_ID0); + + auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN(); + auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)]; + + AscendC::WaitFlag(EVENT_ID0); + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = (params.isKLoopFirst && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if (params.isKLoopLast && (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && + (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0; + } + } + + if (params.isKLoopLast) { + auto layoutCInGm = params.layoutCInGm; + + params.callbackBeforeFixpipe(); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(l0CEventList[l0CListId]); + AscendC::WaitFlag(l0CEventList[l0CListId]); + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0); + AscendC::SetFlag(l0CEventList[l0CListId]); + } else { + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11); + } + l0CListId = (l0CListId + 1 < L0C_STAGES) ? 
(l0CListId + 1) : 0; + + params.callbackAfterFixpipe(); + } + } + + AscendC::LocalTensor l1ATensorList[L1_STAGES]; + AscendC::LocalTensor l1BTensorList[L1_STAGES]; + int32_t l1AEventList[L1_STAGES]; + int32_t l1BEventList[L1_STAGES]; + uint32_t l1ListId{0}; + + AscendC::LocalTensor l0ATensorList[L0A_STAGES]; + int32_t l0AEventList[L0A_STAGES]; + uint32_t l0AListId{0}; + + AscendC::LocalTensor l0BTensorList[L0B_STAGES]; + int32_t l0BEventList[L0B_STAGES]; + uint32_t l0BListId{0}; + + AscendC::LocalTensor l0CTensorList[L0C_STAGES_]; + int32_t l0CEventList[L0C_STAGES_]; + uint32_t l0CListId{0}; + + L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES]; + uint32_t l1TileMmadParamsId{0}; + uint32_t preloadCount{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_preload_tla.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_preload_tla.hpp new file mode 100644 index 00000000..0c3b0d08 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_preload_tla.hpp @@ -0,0 +1,397 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
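Editorial note, not part of the patch: all of these block variants share the ENABLE_SHUFFLE_K option, which derives each core's starting K tile from its block index so that concurrent cores walk the K dimension in different orders, presumably to spread simultaneous GM reads across different tiles. An illustrative standard-C++ trace of the resulting visiting order, with the tile and core counts chosen arbitrarily:

#include <cstdint>
#include <cstdio>

int main() {
    const std::uint32_t kTileCount = 4;  // illustrative
    const std::uint32_t numCores = 3;    // illustrative
    for (std::uint32_t core = 0; core < numCores; ++core) {
        std::uint32_t startTileIdx = core % kTileCount;  // per-core starting tile
        std::printf("core %u visits K tiles:", core);
        for (std::uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) {
            std::uint32_t shuffleKIdx = (startTileIdx + kLoopIdx) % kTileCount;
            std::printf(" %u", shuffleKIdx);
        }
        std::printf("\n");
    }
    return 0;
}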
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_TLA_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_TLA_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" +#include "tla/layout.hpp" +#include "tla/tensor.hpp" + +namespace Catlass::Gemm::Block { + +template +struct BlockMmadTla, L1TileShape_, L0TileShape_, TensorA_, + TensorB_, TensorC_, TensorBias_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2Preload; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using TensorA = TensorA_; + using TensorB = TensorB_; + using TensorC = TensorC_; + using ElementA = typename TensorA::Element; + using LayoutA = typename TensorA::Layout; + using ElementB = typename TensorB::Element; + using LayoutB = typename TensorB::Layout; + using ElementC = typename TensorC::Element; + using LayoutC = typename TensorC::Layout; + + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = typename CopyL0CToGm::ElementSrc; + + using LayoutTagL1A = typename TileCopy_::LayoutTagL1A; + using LayoutTagL1B = typename TileCopy_::LayoutTagL1B; + using LayoutTagL0A = typename TileCopy_::LayoutTagL0A; + using LayoutTagL0B = typename TileCopy_::LayoutTagL0B; + + using L1AAlignHelper = typename TileCopy_::L1AAlignHelper; + using L1BAlignHelper = typename TileCopy_::L1BAlignHelper; + + static_assert(tla::is_tuple::value && tla::is_static::value, + "L1TileShape must be tla::tuple and static!"); + static_assert(tla::is_tuple::value && tla::is_static::value, + "L0TileShape must be tla::tuple and static!"); + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1_TILE_M = tla::get<0>(L1TileShape{}); + static constexpr uint32_t L1_TILE_N = tla::get<1>(L1TileShape{}); + static constexpr uint32_t L1_TILE_K = tla::get<2>(L1TileShape{}); + static constexpr uint32_t L0_TILE_M = tla::get<0>(L0TileShape{}); + static constexpr uint32_t L0_TILE_N = tla::get<1>(L0TileShape{}); + static constexpr uint32_t L0_TILE_K = tla::get<2>(L0TileShape{}); + + // L1 tile size + static constexpr uint32_t L1A_TILE_SIZE = L1_TILE_M * L1_TILE_K * sizeof(ElementA); + static constexpr uint32_t L1B_TILE_SIZE = L1_TILE_N * L1_TILE_K * sizeof(ElementB); + // L0 tile size + static constexpr uint32_t L0A_TILE_SIZE = L0_TILE_M * L0_TILE_K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0_TILE_K * L0_TILE_N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L1_TILE_M * L1_TILE_N * sizeof(ElementAccumulator); + + // Check LayoutC + static_assert(tla::detail::isRowMajor::value, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert((L1A_TILE_SIZE + L1B_TILE_SIZE) * STAGES <= ArchTag::L1_SIZE, "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static_assert(L0A_TILE_SIZE * 
STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert(L0B_TILE_SIZE * STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!"); + + static_assert(L1_TILE_M == L0_TILE_M && L1_TILE_N == L0_TILE_N, + "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + static_assert(L0_TILE_K <= L1_TILE_K, "L0TileShape::K cannot exceed L1TileShape::K"); + + static constexpr auto L1A_LAYOUT = tla::MakeLayout(L1_TILE_M, L1_TILE_K); + static constexpr auto L1B_LAYOUT = tla::MakeLayout(L1_TILE_K, L1_TILE_N); + + /// Construct + CATLASS_DEVICE + BlockMmadTla(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * STAGES; + // Init buffers + for (uint32_t i = 0; i < STAGES; i++) { + // Assign L1/L0A/L0B space for each stages + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_TILE_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_TILE_SIZE * i); + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); + + // Assign event ID for each stages + l1AEventList[i] = i; + l1BEventList[i] = i + STAGES; + l0AEventList[i] = i; + l0BEventList[i] = i + STAGES; + + // The event id that needs to be set before the loop + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + AscendC::SetFlag(l0AEventList[i]); + AscendC::SetFlag(l0BEventList[i]); + } + l0CTensor = resource.l0CBuf.template GetBufferByByte(0); + AscendC::SetFlag(EVENT_ID0); + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmadTla() + { + for (uint32_t i = 0; i < STAGES; i++) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + AscendC::WaitFlag(l0AEventList[i]); + AscendC::WaitFlag(l0BEventList[i]); + } + AscendC::WaitFlag(EVENT_ID0); + } + + /// Perform a block-scoped matrix multiply-accumulate + template + CATLASS_DEVICE void operator()(TensorA &tensorA, TensorB &tensorB, TensorC &tensorC, TensorA &tensorNextA, + TensorB &tensorNextB, GemmCoord const &actualShape, GemmCoord const &actualShapeNext, + bool isFirstBlock, bool hasNextBlock) + { + uint32_t mBlockActual = actualShape.m(); + uint32_t kBlockActual = actualShape.k(); + uint32_t nBlockActual = actualShape.n(); + uint32_t mNextBlockActual = actualShapeNext.m(); + uint32_t kNextBlockActual = actualShapeNext.k(); + uint32_t nNextBlockActual = actualShapeNext.n(); + + uint32_t mRound = RoundUp(mBlockActual); + uint32_t nRound = RoundUp(nBlockActual); + + auto layoutInL0C = tla::MakeLayoutL0C(mRound, nRound); + auto tensorL0C = tla::MakeTensor(l0CTensor, layoutInL0C, Arch::PositionL0C{}); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::WaitFlag(EVENT_ID0); + } + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx(); + } + + uint32_t kTileCount = CeilDiv(kBlockActual); + uint32_t firstTileIdx = startTileIdx % kTileCount; + uint32_t lastTileIdx = (startTileIdx + kTileCount - 1) % kTileCount; + uint32_t kActual = (firstTileIdx < kTileCount - 1) ? 
L1_TILE_K : (kBlockActual - firstTileIdx * L1_TILE_K); + uint32_t kTileCountNext = CeilDiv(kNextBlockActual); + uint32_t firstTileIdxNext = startTileIdx % kTileCountNext; + + if (isFirstBlock) { + // load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto tensorL1A = tla::MakeTensor(l1ATensorList[l1ListId], L1A_LAYOUT, Arch::PositionL1{}); + auto tensorTileA = + GetTile(tensorA, tla::MakeCoord(0, firstTileIdx * L1_TILE_K), tla::MakeShape(mBlockActual, kActual)); + copyGmToL1A(tensorL1A, tensorTileA, tla::MakeShape(mBlockActual, kActual)); + AscendC::SetFlag(l1AEventList[l1ListId]); + + // load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto tensorL1B = tla::MakeTensor(l1BTensorList[l1ListId], L1B_LAYOUT, Arch::PositionL1{}); + auto tensorTileB = + GetTile(tensorB, tla::MakeCoord(firstTileIdx * L1_TILE_K, 0), tla::MakeShape(kActual, nBlockActual)); + copyGmToL1B(tensorL1B, tensorTileB, tla::MakeShape(kActual, nBlockActual)); + AscendC::SetFlag(l1BEventList[l1ListId]); + } + + uint32_t mPartLoop = CeilDiv(mRound); + uint32_t nPartLoop = CeilDiv(nRound); + + // main loop + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; kLoopIdx++) { + uint32_t shuffleKIdx = (startTileIdx + kLoopIdx) % kTileCount; + uint32_t l1ListIdNext = (l1ListId + 1 < STAGES) ? (l1ListId + 1) : 0; + uint32_t kActualNext{0}; + + if (shuffleKIdx != lastTileIdx) { + uint32_t shuffleKIdxNext = (startTileIdx + kLoopIdx + 1) % kTileCount; + kActualNext = + (shuffleKIdxNext < kTileCount - 1) ? L1_TILE_K : (kBlockActual - shuffleKIdxNext * L1_TILE_K); + + auto tensorL1A = tla::MakeTensor(l1ATensorList[l1ListIdNext], L1A_LAYOUT, Arch::PositionL1{}); + auto tensorL1B = tla::MakeTensor(l1BTensorList[l1ListIdNext], L1B_LAYOUT, Arch::PositionL1{}); + auto tensorTileA = GetTile(tensorA, tla::MakeCoord(0, shuffleKIdxNext * L1_TILE_K), + tla::MakeShape(mBlockActual, kActualNext)); + auto tensorTileB = GetTile(tensorB, tla::MakeCoord(shuffleKIdxNext * L1_TILE_K, 0), + tla::MakeShape(kActualNext, nBlockActual)); + + // load next matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + copyGmToL1A(tensorL1A, tensorTileA, tla::MakeShape(mBlockActual, kActualNext)); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + + // load next matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + copyGmToL1B(tensorL1B, tensorTileB, tla::MakeShape(kActualNext, nBlockActual)); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + + // preload next tile from GM to L1 + if (shuffleKIdx == lastTileIdx && hasNextBlock) { + kActualNext = (firstTileIdxNext < kTileCountNext - 1) + ? 
L1_TILE_K + : (kNextBlockActual - firstTileIdxNext * L1_TILE_K); + + // Get L1 tensor for next stage + auto tensorL1A = tla::MakeTensor(l1ATensorList[l1ListIdNext], L1A_LAYOUT, Arch::PositionL1{}); + auto tensorL1B = tla::MakeTensor(l1BTensorList[l1ListIdNext], L1B_LAYOUT, Arch::PositionL1{}); + // Get GM tile for next stage + auto tensorTileA = GetTile(tensorNextA, tla::MakeCoord(0, firstTileIdxNext * L1_TILE_K), + tla::MakeShape(mNextBlockActual, kActualNext)); + auto tensorTileB = GetTile(tensorNextB, tla::MakeCoord(firstTileIdxNext * L1_TILE_K, 0), + tla::MakeShape(kActualNext, nNextBlockActual)); + + // load next matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + copyGmToL1A(tensorL1A, tensorTileA, tla::MakeShape(mNextBlockActual, kActualNext)); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + + // load next matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + copyGmToL1B(tensorL1B, tensorTileB, tla::MakeShape(kActualNext, nNextBlockActual)); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + + // Get L1 tensor for current stage + auto l1ATensor = l1ATensorList[l1ListId]; + auto l1BTensor = l1BTensorList[l1ListId]; + auto tensorL1A = tla::MakeTensor(l1ATensor, L1A_LAYOUT, Arch::PositionL1{}); + auto tensorL1B = tla::MakeTensor(l1BTensor, L1B_LAYOUT, Arch::PositionL1{}); + // Get the loop nums on L0 + uint32_t kPartLoop = CeilDiv(kActual); + + for (int mPartIdx = 0; mPartIdx < mPartLoop; mPartIdx++) { + uint32_t mPartActual = (mPartIdx < mPartLoop - 1) ? L0_TILE_M : (mRound - mPartIdx * L0_TILE_M); + + for (int kPartIdx = 0; kPartIdx < kPartLoop; kPartIdx++) { + uint32_t kPartActual = (kPartIdx < kPartLoop - 1) ? L0_TILE_K : (kActual - kPartIdx * L0_TILE_K); + + // Locate the current tile on L0A + auto l0ATile = l0ATensorList[l0AListId]; + auto layoutAInL0 = tla::MakeLayout(mPartActual, kPartActual); + auto tensorL0A = tla::MakeTensor(l0ATile, layoutAInL0, Arch::PositionL0A{}); + // Locate the current tile of matrix A on L1 + auto tensorTileL1A = GetTile(tensorL1A, tla::MakeCoord(mPartIdx * L0_TILE_M, kPartIdx * L0_TILE_K), + tla::MakeShape(mPartActual, kPartActual)); + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[l1ListId]); + } + + // Load current tile from L1 to L0A + copyL1ToL0A(tensorL0A, tensorTileL1A); + + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[l1ListId]); + } + + for (int nPartIdx = 0; nPartIdx < nPartLoop; nPartIdx++) { + uint32_t nPartActual = (nPartIdx < nPartLoop - 1) ? 
L0_TILE_N : (nRound - nPartIdx * L0_TILE_N); + + // Locate the current tile on L0B + auto l0BTile = l0BTensorList[l0BListId]; + auto layoutBInL0 = tla::MakeLayout(kPartActual, nPartActual); + auto tensorL0B = tla::MakeTensor(l0BTile, layoutBInL0, Arch::PositionL0B{}); + // Locate the current tile of matrix B on L1 + auto tensorTileL1B = + GetTile(tensorL1B, tla::MakeCoord(kPartIdx * L0_TILE_K, nPartIdx * L0_TILE_N), + tla::MakeShape(kPartActual, nPartActual)); + + // Wait for mmad finished + AscendC::WaitFlag(l0BEventList[l0BListId]); + // If the current tile is the first one on the k&n axis, wait for loading matrix B from GM to L1 + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[l1ListId]); + } + + // Load current tile from L1 to L0B + copyL1ToL0B(tensorL0B, tensorTileL1B); + + // If the current tile is the last one on the k&n axis, notify to load matrix B from GM to L1 + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[l1ListId]); + } + // Notify to do mmad + AscendC::SetFlag(EVENT_ID0); + + // Locate the current tile on L0C + auto tensorTileL0C = + GetTile(tensorL0C, tla::MakeCoord(mPartIdx * L0_TILE_M, nPartIdx * L0_TILE_N), + tla::MakeShape(mPartActual, nPartActual)); + + // Compute the matrix multiplication on L0A and L0B and write the result to the accumulator + // Wait for loading L0B + AscendC::WaitFlag(EVENT_ID0); + + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = ((kLoopIdx == 0) && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if ((kLoopIdx == kTileCount - 1) && (mPartIdx == mPartLoop - 1) && + (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + // Perform calculation operations + tileMmad(tensorTileL0C, tensorL0A, tensorL0B, mPartActual, nPartActual, kPartActual, initC, + unitFlag); + + // Notify to move the next L0B tile + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < STAGES) ? 
(l0AListId + 1) : 0; + } + } + l1ListId = l1ListIdNext; + kActual = kActualNext; + } + + // copy block out + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + copyL0CToGm(tensorC, tensorL0C); + AscendC::SetFlag(EVENT_ID0); + } else { + copyL0CToGm(tensorC, tensorL0C, 0b11); + } + } + +protected: + // Multi-stage tensors list + AscendC::LocalTensor l1ATensorList[STAGES]; + AscendC::LocalTensor l1BTensorList[STAGES]; + AscendC::LocalTensor l0ATensorList[STAGES]; + AscendC::LocalTensor l0BTensorList[STAGES]; + AscendC::LocalTensor l0CTensor; + + // Multi-stage event id list + int32_t l1AEventList[STAGES]; + int32_t l1BEventList[STAGES]; + int32_t l0AEventList[STAGES]; + int32_t l0BEventList[STAGES]; + + // The id of current stage + uint32_t l1ListId{0}; + uint32_t l0AListId{0}; + uint32_t l0BListId{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_TLA_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_w8a16.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_w8a16.hpp new file mode 100644 index 00000000..918ae75a --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_mmad_w8a16.hpp @@ -0,0 +1,620 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_MMAD_W8A16_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_MMAD_W8A16_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" + +namespace Catlass::Gemm::Block { + +template +struct PrologueCast { + using ElementIn = ElementIn_; + using ElementOut = ElementOut_; + using TileShape = TileShape_; + using Layout = Layout_; + + static constexpr uint32_t ELE_NUM_PER_BLK_INT8 = BYTE_PER_BLK / sizeof(ElementIn); + static constexpr uint32_t ELE_NUM_PER_BLK_HALF = BYTE_PER_BLK / sizeof(ElementOut); + static constexpr uint32_t COMPUTE_LEN = 32 * 1024; + static constexpr uint32_t TILES_PER_LOOP = 32; + + // Construct + CATLASS_DEVICE + PrologueCast(Arch::Resource &resource, uint32_t ubBufAddrStart = 0) + { + if (g_coreType == AscendC::AIV) { + uint32_t ubOffset = ubBufAddrStart; + uint32_t ubInSize = COMPUTE_LEN * sizeof(ElementIn); + uint32_t ubOutSize = COMPUTE_LEN * sizeof(ElementOut); + // Init buffers + for (uint32_t i = 0; i < STAGES; ++i) { + ubInTensorList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += ubInSize; + ubOutTensorList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += ubOutSize; + + ubEventList[i] = i; + AscendC::SetFlag(ubEventList[i]); + AscendC::SetFlag(ubEventList[i]); + } + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmDst, AscendC::GlobalTensor const &gmSrc, + Layout const &layoutDst, Layout const &layoutSrc, half deqScalar, half deqZeroPoint) + { + uint32_t tileNum = layoutSrc.shape(0); + uint32_t tileLen = layoutSrc.shape(1); + uint32_t tileLenRoundInt8 = RoundUp(layoutSrc.shape(1), ELE_NUM_PER_BLK_INT8); + uint64_t tileStrideSrc = layoutSrc.stride(0); + uint64_t tileStrideDst = layoutDst.stride(0); + if constexpr (std::is_same_v) { + tileNum = layoutSrc.shape(1); + tileLen = layoutSrc.shape(0); + tileLenRoundInt8 = RoundUp(layoutSrc.shape(0), ELE_NUM_PER_BLK_INT8); + tileStrideSrc = layoutSrc.stride(1); + tileStrideDst = layoutDst.stride(1); + } + uint32_t tilesPerAiv = tileNum / AscendC::GetSubBlockNum(); + if (AscendC::GetSubBlockIdx() < (tileNum % AscendC::GetSubBlockNum())) { + tilesPerAiv++; + } + uint64_t taskOffsetSrc = AscendC::GetSubBlockIdx() * tilesPerAiv * tileStrideSrc; + uint64_t taskOffsetDst = AscendC::GetSubBlockIdx() * tilesPerAiv * tileStrideDst; + if (AscendC::GetSubBlockIdx() >= (tileNum % AscendC::GetSubBlockNum())) { + taskOffsetSrc += (tileNum % AscendC::GetSubBlockNum()) * tileStrideSrc; + taskOffsetDst += (tileNum % AscendC::GetSubBlockNum()) * tileStrideDst; + } + uint32_t loops = CeilDiv(tilesPerAiv, TILES_PER_LOOP); + uint32_t pingpong = 0; + for (uint32_t loopIdx = 0; loopIdx < loops; ++loopIdx) { + uint32_t actualTiles = TILES_PER_LOOP; + if (loopIdx == loops - 1) { + actualTiles = tilesPerAiv - loopIdx * TILES_PER_LOOP; + } + uint64_t tileOffsetSrc = loopIdx * TILES_PER_LOOP * tileStrideSrc; + AscendC::DataCopyExtParams dataCopyParamsIn(actualTiles, tileLen * sizeof(ElementIn), + (tileStrideSrc - tileLen) * sizeof(ElementIn), + (tileLenRoundInt8 - tileLen) / ELE_NUM_PER_BLK_INT8, 0); + AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); + + AscendC::WaitFlag(ubEventList[pingpong]); + AscendC::DataCopyPad(ubInTensorList[pingpong], gmSrc[taskOffsetSrc + tileOffsetSrc], dataCopyParamsIn, + 
padParams); + + AscendC::SetFlag(ubEventList[pingpong]); + AscendC::WaitFlag(ubEventList[pingpong]); + + AscendC::WaitFlag(ubEventList[pingpong]); + + AscendC::Cast(ubOutTensorList[pingpong], ubInTensorList[pingpong], AscendC::RoundMode::CAST_NONE, + actualTiles * tileLenRoundInt8); + AscendC::PipeBarrier(); + AscendC::SetFlag(ubEventList[pingpong]); + + AscendC::Adds(ubOutTensorList[pingpong], ubOutTensorList[pingpong], deqZeroPoint, + actualTiles * tileLenRoundInt8); + AscendC::PipeBarrier(); + + AscendC::Muls(ubOutTensorList[pingpong], ubOutTensorList[pingpong], deqScalar, + actualTiles * tileLenRoundInt8); + AscendC::PipeBarrier(); + + AscendC::SetFlag(ubEventList[pingpong]); + AscendC::WaitFlag(ubEventList[pingpong]); + + uint64_t tileOffsetDst = loopIdx * TILES_PER_LOOP * tileStrideDst; + AscendC::DataCopyExtParams dataCopyParamsOut(actualTiles, tileLen * sizeof(ElementOut), + (tileLenRoundInt8 - tileLen) / ELE_NUM_PER_BLK_HALF, + (tileStrideDst - tileLen) * sizeof(ElementOut), 0); + AscendC::DataCopyPad(gmDst[taskOffsetDst + tileOffsetDst], ubOutTensorList[pingpong], dataCopyParamsOut); + AscendC::SetFlag(ubEventList[pingpong]); + + pingpong = (pingpong + 1) % STAGES; + } + } + + /// Destructor + CATLASS_DEVICE + ~PrologueCast() + { + if (g_coreType == AscendC::AIV) { + for (uint32_t i = 0; i < STAGES; ++i) { + AscendC::WaitFlag(ubEventList[i]); + AscendC::WaitFlag(ubEventList[i]); + } + } + } + +protected: + /// Data members + AscendC::LocalTensor ubInTensorList[STAGES]; + AscendC::LocalTensor ubOutTensorList[STAGES]; + + int32_t ubEventList[STAGES]; +}; + +template +struct BlockMmad, L1TileShape_, L0TileShape_, AType_, BType_, + CType_, BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = MmadAtlasA2Preload; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + using TileShapeB = MatrixShape; + using PrologueCastB = PrologueCast; // no use of TileShapeB + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + 
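+    // L0A/L0B are partitioned into STAGES ping-pong buffers below; L0C keeps a single accumulator tile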
static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + + // Check LayoutA + static_assert(std::is_same_v || std::is_same_v, + "LayoutA only support RowMajor/ColumnMajor yet!"); + + // Check LayoutB + static_assert(std::is_same_v || std::is_same_v, + "LayoutB only support RowMajor/ColumnMajor yet!"); + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert((L1A_SIZE * STAGES + L1B_SIZE * STAGES) <= ArchTag::L1_SIZE, "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L0TileShape::M * L0TileShape::N * sizeof(ElementAccumulator); + static_assert((L0A_TILE_SIZE * STAGES) <= L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert((L0B_TILE_SIZE * STAGES) <= L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE <= L0C_SIZE, "L0TileShape exceeding the L0C space!"); + + static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + static_assert(L0TileShape::K <= L1TileShape::K, "L0TileShape::K cannot exceed L1TileShape::K"); + + /// Construct + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) : prologueCastB(resource) + { + if (g_coreType == AscendC::AIC) { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_SIZE * STAGES; + // Init buffers + for (uint32_t i = 0; i < STAGES; i++) { + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_SIZE * i); + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + + l1AEventList[i] = i; + l1BEventList[i] = i + STAGES; + l0AEventList[i] = i; + l0BEventList[i] = i + STAGES; + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + AscendC::SetFlag(l0AEventList[i]); + AscendC::SetFlag(l0BEventList[i]); + } + l0CTensor = resource.l0CBuf.template GetBufferByByte(0); + AscendC::SetFlag(EVENT_ID0); + + Arch::CrossCoreSetFlag<0x2, PIPE_MTE2>(notifyAiv[0]); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE2>(notifyAiv[1]); + } + } + + /// Destructor + CATLASS_DEVICE + ~BlockMmad() + { + if (g_coreType == AscendC::AIC) { + for (uint32_t i = 0; i < STAGES; i++) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + AscendC::WaitFlag(l0AEventList[i]); + AscendC::WaitFlag(l0BEventList[i]); + } + AscendC::WaitFlag(EVENT_ID0); + } else { + Arch::CrossCoreWaitFlag(notifyAiv[0]); + Arch::CrossCoreWaitFlag(notifyAiv[1]); + } + } + + /// Prologue: cast int8_t to half (w8a16) + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmBlockB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmNextBlockB, AscendC::GlobalTensor const &gmBWksp, + GemmCoord const &actualShape, GemmCoord const &actualShapeNext, bool isFirstBlock, + bool hasNextBlock, half deqScalar, half deqZeroPoint) + { + uint32_t kTileCount = 
CeilDiv(actualShape.k()); + uint32_t kTileCountNext = CeilDiv(actualShapeNext.k()); + + uint32_t wkspStrideB = L1TileShape::N; + if (std::is_same_v) { + wkspStrideB = L1TileShape::K; + } + + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx() / 2; + } + uint32_t firstTileIdx = startTileIdx % kTileCount; + uint32_t lastTileIdx = (startTileIdx + kTileCount - 1) % kTileCount; + uint32_t kActual = + (firstTileIdx < kTileCount - 1) ? L1TileShape::K : (actualShape.k() - firstTileIdx * L1TileShape::K); + uint32_t firstTileIdxNext = startTileIdx % kTileCountNext; + + // k loop + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; kLoopIdx++) { + uint32_t shuffleKIdx = (startTileIdx + kLoopIdx) % kTileCount; + // Load first matrix B tile in total kernel loop from GM to UB + if (shuffleKIdx == firstTileIdx && isFirstBlock) { + MatrixCoord gmTileBOffset{shuffleKIdx * L1TileShape::K, 0}; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + // Load first matrix B tile from GM to UB + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + auto layoutWkspB = LayoutB{kActual, actualShape.n(), wkspStrideB}; + + Arch::CrossCoreWaitFlag(notifyAiv[l1ListId]); + prologueCastB(gmBWksp[l1ListId * L1TileShape::K * L1TileShape::N], gmTileB, layoutWkspB, layoutTileB, + deqScalar, deqZeroPoint); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(notifyAic[l1ListId]); + } + + uint32_t l1ListIdNext = (l1ListId + 1 < STAGES) ? (l1ListId + 1) : 0; + uint32_t kActualNext{0}; + + // preload next tile from GM to UB + if (shuffleKIdx != lastTileIdx) { + uint32_t shuffleKIdxNext = (startTileIdx + kLoopIdx + 1) % kTileCount; + // Get GM tensor for next stage + kActualNext = (shuffleKIdxNext < kTileCount - 1) ? L1TileShape::K + : (actualShape.k() - shuffleKIdxNext * L1TileShape::K); + MatrixCoord gmTileBOffset{shuffleKIdxNext * L1TileShape::K, 0}; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + // load next matrix B tile from GM to UB + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActualNext, actualShape.n())); + auto layoutWkspB = LayoutB{kActualNext, actualShape.n(), wkspStrideB}; + + Arch::CrossCoreWaitFlag(notifyAiv[l1ListIdNext]); + prologueCastB(gmBWksp[l1ListIdNext * L1TileShape::K * L1TileShape::N], gmTileB, layoutWkspB, + layoutTileB, deqScalar, deqZeroPoint); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(notifyAic[l1ListIdNext]); + } + if (shuffleKIdx == lastTileIdx && hasNextBlock) { + // Get GM tensor for next stage + kActualNext = (firstTileIdxNext < kTileCountNext - 1) + ? 
L1TileShape::K + : (actualShapeNext.k() - firstTileIdxNext * L1TileShape::K); + MatrixCoord gmTileBOffset{firstTileIdxNext * L1TileShape::K, 0}; + auto gmTileB = gmNextBlockB[layoutB.GetOffset(gmTileBOffset)]; + // load next matrix B tile from GM to UB + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActualNext, actualShapeNext.n())); + auto layoutWkspB = LayoutB{kActualNext, actualShapeNext.n(), wkspStrideB}; + + Arch::CrossCoreWaitFlag(notifyAiv[l1ListIdNext]); + prologueCastB(gmBWksp[l1ListIdNext * L1TileShape::K * L1TileShape::N], gmTileB, layoutWkspB, + layoutTileB, deqScalar, deqZeroPoint); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(notifyAic[l1ListIdNext]); + } + l1ListId = l1ListIdNext; + kActual = kActualNext; + } + } + + /// Perform a block-scoped matrix multiply-accumulate + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmBlockA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmBlockB, AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutC, AscendC::GlobalTensor const &gmNextBlockA, + GemmCoord const &actualShape, GemmCoord const &actualShapeNext, bool isFirstBlock, + bool hasNextBlock) + { + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + auto layoutAInL1 = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + auto layoutBInL1 = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + auto layoutInL0C = LayoutCInL0::MakeLayoutInL0C(MakeCoord(mRound, nRound)); + + uint32_t kTileCount = CeilDiv(actualShape.k()); + uint32_t kTileCountNext = CeilDiv(actualShapeNext.k()); + + uint32_t wkspStrideB = L1TileShape::N; + if (std::is_same_v) { + wkspStrideB = L1TileShape::K; + } + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::WaitFlag(EVENT_ID0); + } + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx(); + } + uint32_t firstTileIdx = startTileIdx % kTileCount; + uint32_t lastTileIdx = (startTileIdx + kTileCount - 1) % kTileCount; + uint32_t kActual = + (firstTileIdx < kTileCount - 1) ? L1TileShape::K : (actualShape.k() - firstTileIdx * L1TileShape::K); + uint32_t firstTileIdxNext = startTileIdx % kTileCountNext; + + uint32_t mPartLoop = CeilDiv(mRound); + uint32_t nPartLoop = CeilDiv(nRound); + + // k loop + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; kLoopIdx++) { + uint32_t shuffleKIdx = (startTileIdx + kLoopIdx) % kTileCount; + // Load first matrix A tile in total kernel loop from GM to L1 + if (shuffleKIdx == firstTileIdx && isFirstBlock) { + MatrixCoord gmTileAOffset{0, shuffleKIdx * L1TileShape::K}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + // Load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1ListId], gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + + // Load first matrix B tile from GM to L1 + Arch::CrossCoreWaitFlag(notifyAic[l1ListId]); + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileB = LayoutB{kActual, actualShape.n(), wkspStrideB}; + copyGmToL1B(l1BTensorList[l1ListId], gmBlockB[l1ListId * L1TileShape::K * L1TileShape::N], layoutBInL1, + layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE2>(notifyAiv[l1ListId]); + } + + uint32_t l1ListIdNext = (l1ListId + 1 < STAGES) ? 
(l1ListId + 1) : 0; + uint32_t kActualNext{0}; + + // preload next tile from GM to L1 + if (shuffleKIdx != lastTileIdx) { + uint32_t shuffleKIdxNext = (startTileIdx + kLoopIdx + 1) % kTileCount; + // Get L1 tensor for next stage + auto l1ATensor = l1ATensorList[l1ListIdNext]; + auto l1BTensor = l1BTensorList[l1ListIdNext]; + // Get GM tensor for next stage + kActualNext = (shuffleKIdxNext < kTileCount - 1) ? L1TileShape::K + : (actualShape.k() - shuffleKIdxNext * L1TileShape::K); + MatrixCoord gmTileAOffset{0, shuffleKIdxNext * L1TileShape::K}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + + // load next matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActualNext)); + copyGmToL1A(l1ATensor, gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + + // load next matrix B tile from GM to L1 + Arch::CrossCoreWaitFlag(notifyAic[l1ListIdNext]); + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + auto layoutTileB = LayoutB{kActualNext, actualShape.n(), wkspStrideB}; + copyGmToL1B(l1BTensor, gmBlockB[l1ListIdNext * L1TileShape::K * L1TileShape::N], layoutBInL1, + layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE2>(notifyAiv[l1ListIdNext]); + } + if (shuffleKIdx == lastTileIdx && hasNextBlock) { + // Get L1 tensor for next stage + auto l1ATensor = l1ATensorList[l1ListIdNext]; + auto l1BTensor = l1BTensorList[l1ListIdNext]; + // Get GM tensor for next stage + kActualNext = (firstTileIdxNext < kTileCountNext - 1) + ? L1TileShape::K + : (actualShapeNext.k() - firstTileIdxNext * L1TileShape::K); + MatrixCoord gmTileAOffset{0, firstTileIdxNext * L1TileShape::K}; + auto gmTileA = gmNextBlockA[layoutA.GetOffset(gmTileAOffset)]; + // load next matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShapeNext.m(), kActualNext)); + copyGmToL1A(l1ATensor, gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + + // load next matrix B tile from GM to L1 + Arch::CrossCoreWaitFlag(notifyAic[l1ListIdNext]); + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + auto layoutTileB = LayoutB{kActualNext, actualShapeNext.n(), wkspStrideB}; + copyGmToL1B(l1BTensor, gmBlockB[l1ListIdNext * L1TileShape::K * L1TileShape::N], layoutBInL1, + layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE2>(notifyAiv[l1ListIdNext]); + } + + // Get L1 tensor for current stage + auto l1ATensor = l1ATensorList[l1ListId]; + auto l1BTensor = l1BTensorList[l1ListId]; + + // Get the loop nums on L0 + uint32_t kPartLoop = CeilDiv(kActual); + + uint32_t l0ABufId = 0; + uint32_t l0BBufId = 0; + + for (int mPartIdx = 0; mPartIdx < mPartLoop; mPartIdx++) { + uint32_t mPartActual = + (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (mRound - mPartIdx * L0TileShape::M); + + for (int kPartIdx = 0; kPartIdx < kPartLoop; kPartIdx++) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) ? 
L0TileShape::K : (kActual - kPartIdx * L0TileShape::K); + + // Locate the current tile on L0A + auto l0ATile = l0ATensorList[l0ABufId]; + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + // Locate the current tile of matrix A on L1 + MatrixCoord l1AOffset{mPartIdx * L0TileShape::M, kPartIdx * L0TileShape::K}; + auto l1ATile = l1ATensor[layoutAInL1.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0ABufId]); + if (mPartIdx == 0 && kPartIdx == 0) { + AscendC::WaitFlag(l1AEventList[l1ListId]); + } + + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, layoutAInL1); + + if (mPartIdx == mPartLoop - 1 && kPartIdx == kPartLoop - 1) { + AscendC::SetFlag(l1AEventList[l1ListId]); + } + + for (int nPartIdx = 0; nPartIdx < nPartLoop; nPartIdx++) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? L0TileShape::N : (nRound - nPartIdx * L0TileShape::N); + + // Locate the current tile on L0B + auto l0BTile = l0BTensorList[l0BBufId]; + LayoutBInL0 layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + // Locate the current tile of matrix B on L1 + MatrixCoord l1BOffset{kPartIdx * L0TileShape::K, nPartIdx * L0TileShape::N}; + auto l1BTile = l1BTensor[layoutBInL1.GetOffset(l1BOffset)]; + + // Wait for mmad finished + AscendC::WaitFlag(l0BEventList[l0BBufId]); + // If the current tile is the first one on the k&n axis, wait for loading matrix B from GM to L1 + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[l1ListId]); + } + + // Load current tile from L1 to L0B + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, layoutBInL1); + + // If the current tile is the last one on the k&n axis, notify to load matrix B from GM to L1 + if (kPartIdx == kPartLoop - 1 && nPartIdx == nPartLoop - 1) { + AscendC::SetFlag(l1BEventList[l1ListId]); + } + // Notify to do mmad + AscendC::SetFlag(EVENT_ID0); + + // Locate the current tile on L0C + MatrixCoord l0COffset{mPartIdx * L0TileShape::M, nPartIdx * L0TileShape::N}; + auto l0CTile = l0CTensor[layoutInL0C.GetOffset(l0COffset)]; + + // Compute the matrix multiplication on L0A and L0B and write the result to the accumulator + // Wait for loading L0B + AscendC::WaitFlag(EVENT_ID0); + + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = ((kLoopIdx == 0) && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if ((kLoopIdx == kTileCount - 1) && (mPartIdx == mPartLoop - 1) && + (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + // Perform calculation operations + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + // Notify to move the next L0B tile + AscendC::SetFlag(l0BEventList[l0BBufId]); + + l0BBufId = (l0BBufId + 1 < STAGES) ? (l0BBufId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0ABufId]); + l0ABufId = (l0ABufId + 1 < STAGES) ? 
(l0ABufId + 1) : 0; + } + } + + l1ListId = l1ListIdNext; + kActual = kActualNext; + } + + // copy block out + LayoutC layoutBlock = layoutC.GetTileLayout(actualShape.GetCoordMN()); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + copyL0CToGm(gmBlockC, l0CTensor, layoutBlock, layoutInL0C); + AscendC::SetFlag(EVENT_ID0); + } else { + copyL0CToGm(gmBlockC, l0CTensor, layoutBlock, layoutInL0C, 0b11); + } + } + +protected: + /// Data members + AscendC::LocalTensor l1ATensorList[STAGES]; + AscendC::LocalTensor l1BTensorList[STAGES]; + AscendC::LocalTensor l0ATensorList[STAGES]; + AscendC::LocalTensor l0BTensorList[STAGES]; + AscendC::LocalTensor l0CTensor; + + int32_t l1AEventList[STAGES]; + int32_t l1BEventList[STAGES]; + int32_t l0AEventList[STAGES]; + int32_t l0BEventList[STAGES]; + + Arch::CrossCoreFlag notifyAic[STAGES] = {EVENT_ID0, EVENT_ID1}; + Arch::CrossCoreFlag notifyAiv[STAGES] = {EVENT_ID2, EVENT_ID3}; + + uint32_t l1ListId{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; + PrologueCastB prologueCastB; +}; + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_MMAD_W8A16_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_swizzle.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_swizzle.hpp new file mode 100644 index 00000000..18239cf6 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/block/block_swizzle.hpp @@ -0,0 +1,374 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_BLOCK_BLOCK_SWIZZLE_HPP +#define CATLASS_GEMM_BLOCK_BLOCK_SWIZZLE_HPP + +#include "catlass/catlass.hpp" +#include "catlass/detail/alignment.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Block { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Block swizzling function for Gemms +template +struct GemmIdentityBlockSwizzle { + /// Data members + + GemmCoord problemShape; + MatrixCoord tileMN; + MatrixCoord loopsMN; + + /// Methods + + CATLASS_DEVICE + GemmIdentityBlockSwizzle() {} + + CATLASS_DEVICE + GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, MatrixCoord const &tileMN_) + : problemShape(problemShape_), tileMN(tileMN_) + { + loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); + } + + CATLASS_DEVICE + GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, MatrixCoord const &tileMN_, MatrixCoord const &loopsMN_) + : problemShape(problemShape_), tileMN(tileMN_), loopsMN(loopsMN_) + {} + + CATLASS_DEVICE + void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_) + { + problemShape = problemShape_; + tileMN = tileMN_; + + loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); + } + + CATLASS_DEVICE + void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_, MatrixCoord const &loopsMN_) + { + problemShape = problemShape_; + tileMN = tileMN_; + loopsMN = loopsMN_; + } + + CATLASS_DEVICE + uint32_t GetCoreLoops() const + { + return loopsMN.row() * loopsMN.column(); + } + + CATLASS_DEVICE + uint32_t GetBatchIdx(uint32_t taskIdx) + { + return taskIdx / (GetCoreLoops()); + } + + CATLASS_DEVICE + GemmCoord GetBlockCoord(uint32_t taskIdx) + { + uint32_t innerIdx = taskIdx % GetCoreLoops(); + if constexpr (SwizzleDirection == 0) { // Zn + uint32_t tileBlockLoop = CeilDiv(loopsMN.row(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.column()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.column()); + + uint32_t nRow = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nRow = loopsMN.row() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; + uint32_t nIdx = inTileBlockIdx / nRow; + if (tileBlockIdx % 2 == 1) { + nIdx = loopsMN.column() - nIdx - 1; + } + return GemmCoord{mIdx, nIdx, 0}; + } else if constexpr (SwizzleDirection == 1) { // Nz + uint32_t tileBlockLoop = CeilDiv(loopsMN.column(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.row()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.row()); + + uint32_t nCol = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nCol = loopsMN.column() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = inTileBlockIdx / nCol; + uint32_t nIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; + if (tileBlockIdx % 2 == 1) { + mIdx = loopsMN.row() - mIdx - 1; + } + return GemmCoord{mIdx, nIdx, 0}; + } + } + + CATLASS_DEVICE + GemmCoord GetActualBlockShape(GemmCoord blockCoord) + { + uint32_t mActual = + (blockCoord.m() == (loopsMN.row() - 1)) ? (problemShape.m() - blockCoord.m() * tileMN.row()) : tileMN.row(); + uint32_t nActual = (blockCoord.n() == (loopsMN.column() - 1)) + ? 
(problemShape.n() - blockCoord.n() * tileMN.column()) + : tileMN.column(); + uint32_t kActual = problemShape.k(); + return GemmCoord{mActual, nActual, kActual}; + } +}; + +/// Block swizzling function for Splitk Gemms +template +struct SplitkGemmIdentityBlockSwizzle { + /// Data members + + GemmCoord problemShape; + GemmCoord tileShape; + GemmCoord loopsMNK; + uint32_t splitkFactor = 1; // split k dim into virtual cores + + /// Methods + + CATLASS_DEVICE + SplitkGemmIdentityBlockSwizzle() {} + + CATLASS_DEVICE + SplitkGemmIdentityBlockSwizzle(GemmCoord const &problemShape_, GemmCoord const &tileShape_, + uint32_t splitkFactor_ = 1) + : problemShape(problemShape_), tileShape(tileShape_), splitkFactor(splitkFactor_) + { + loopsMNK = CeilDiv(problemShape, tileShape); + } + + CATLASS_DEVICE + uint32_t GetKIdxBySplitkSliceIdx(uint32_t splitkSliceIdx) const + { + if (splitkSliceIdx < loopsMNK.k() % splitkFactor) { + return (loopsMNK.k() / splitkFactor + 1) * splitkSliceIdx; + } else { + return splitkSliceIdx * (loopsMNK.k() / splitkFactor) + loopsMNK.k() % splitkFactor; + } + } + + CATLASS_DEVICE + uint32_t GetSplitkSliceIdx(uint32_t taskIdx) const + { + uint32_t mnLoops = loopsMNK.m() * loopsMNK.n(); + return taskIdx % GetCoreLoops() / mnLoops; + } + + CATLASS_DEVICE + uint32_t GetCoreLoops() const + { + return loopsMNK.m() * loopsMNK.n() * splitkFactor; + } + + CATLASS_DEVICE + uint32_t GetBatchIdx(uint32_t taskIdx) + { + return taskIdx / GetCoreLoops(); + } + + CATLASS_DEVICE + GemmCoord GetBlockCoord(uint32_t taskIdx) + { + uint32_t splitkSliceIdx = GetSplitkSliceIdx(taskIdx); + uint32_t kIdx = GetKIdxBySplitkSliceIdx(splitkSliceIdx); + + uint32_t innerIdx = taskIdx % (loopsMNK.m() * loopsMNK.n()); + if constexpr (SwizzleDirection == 0) { // Zn + uint32_t tileBlockLoop = CeilDiv(loopsMNK.m(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMNK.n()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMNK.n()); + + uint32_t nRow = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nRow = loopsMNK.m() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; + uint32_t nIdx = inTileBlockIdx / nRow; + if (tileBlockIdx % 2 == 1) { + nIdx = loopsMNK.n() - nIdx - 1; + } + return GemmCoord{mIdx, nIdx, kIdx}; + } else if constexpr (SwizzleDirection == 1) { // Nz + uint32_t tileBlockLoop = CeilDiv(loopsMNK.n(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMNK.m()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMNK.m()); + + uint32_t nCol = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nCol = loopsMNK.n() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = inTileBlockIdx / nCol; + uint32_t nIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; + if (tileBlockIdx % 2 == 1) { + mIdx = loopsMNK.m() - mIdx - 1; + } + return GemmCoord{mIdx, nIdx, kIdx}; + } + } + + CATLASS_DEVICE + GemmCoord GetActualBlockShape(GemmCoord blockCoord, uint32_t splitkSliceIdx) + { + uint32_t splitkSliceLen; + if (splitkSliceIdx < loopsMNK.k() % splitkFactor) { + splitkSliceLen = (loopsMNK.k() / splitkFactor + 1) * tileShape.k(); + } else { + splitkSliceLen = (loopsMNK.k() / splitkFactor) * tileShape.k(); + } + uint32_t mActual = (blockCoord.m() == (loopsMNK.m() - 1)) ? (problemShape.m() - blockCoord.m() * tileShape.m()) + : tileShape.m(); + uint32_t nActual = (blockCoord.n() == (loopsMNK.n() - 1)) ? 
(problemShape.n() - blockCoord.n() * tileShape.n()) + : tileShape.n(); + uint32_t kActual = (splitkSliceIdx == (splitkFactor - 1)) ? (problemShape.k() - blockCoord.k() * tileShape.k()) + : splitkSliceLen; + return GemmCoord{mActual, nActual, kActual}; + } +}; + +/// Block swizzling function for Gemms +template +struct GemmIdentityBlockSwizzleL1FullLoad { + /// Data members + + GemmCoord problemShape; + MatrixCoord tileMN; + MatrixCoord loopsMN; + + uint32_t loopsPerCore; + uint32_t loopsTail; + uint32_t aicCoreNum; + + /// Methods + + CATLASS_DEVICE + GemmIdentityBlockSwizzleL1FullLoad() {} + + CATLASS_DEVICE + GemmIdentityBlockSwizzleL1FullLoad(GemmCoord const &problemShape_, MatrixCoord const &tileMN_) + : problemShape(problemShape_), tileMN(tileMN_) + { + loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); + uint32_t loopsTotalNum = GetCoreLoops(); + aicCoreNum = AscendC::GetBlockNum(); + loopsPerCore = loopsTotalNum / aicCoreNum; + loopsTail = loopsTotalNum % aicCoreNum; + } + + CATLASS_DEVICE + GemmIdentityBlockSwizzleL1FullLoad(GemmCoord const &problemShape_, MatrixCoord const &tileMN_, + MatrixCoord const &loopsMN_) + : problemShape(problemShape_), tileMN(tileMN_), loopsMN(loopsMN_) + { + uint32_t loopsTotalNum = GetCoreLoops(); + aicCoreNum = AscendC::GetBlockNum(); + loopsPerCore = loopsTotalNum / aicCoreNum; + loopsTail = loopsTotalNum % aicCoreNum; + } + + CATLASS_DEVICE + void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_) + { + problemShape = problemShape_; + tileMN = tileMN_; + loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); + + uint32_t loopsTotalNum = GetCoreLoops(); + aicCoreNum = AscendC::GetBlockNum(); + loopsPerCore = loopsTotalNum / aicCoreNum; + loopsTail = loopsTotalNum % aicCoreNum; + } + + CATLASS_DEVICE + void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_, MatrixCoord const &loopsMN_) + { + problemShape = problemShape_; + tileMN = tileMN_; + loopsMN = loopsMN_; + + uint32_t loopsTotalNum = GetCoreLoops(); + aicCoreNum = AscendC::GetBlockNum(); + loopsPerCore = loopsTotalNum / aicCoreNum; + loopsTail = loopsTotalNum % aicCoreNum; + } + + CATLASS_DEVICE + uint32_t GetCoreLoops() const + { + return loopsMN.row() * loopsMN.column(); + } + + ////// WARNING: current strategy not support GetBatchIdx() + + CATLASS_DEVICE + GemmCoord GetBlockCoord(uint32_t taskIdx) + { + // calculate innerIdx from taskIdx + uint32_t CoreIdx = taskIdx % aicCoreNum; + uint32_t innerCoreIdx = taskIdx / aicCoreNum; + uint32_t innerIdx = CoreIdx * loopsPerCore + innerCoreIdx; + if (CoreIdx < loopsTail) { + innerIdx += CoreIdx; + } else { + innerIdx += loopsTail; + } + // calculate block location in swizzle, using innerIdx + if constexpr (SwizzleDirection == 0) { // Zn + uint32_t tileBlockLoop = CeilDiv(loopsMN.row(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.column()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.column()); + + uint32_t nRow = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nRow = loopsMN.row() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; + uint32_t nIdx = inTileBlockIdx / nRow; + if (tileBlockIdx % 2 == 1) { + nIdx = loopsMN.column() - nIdx - 1; + } + return GemmCoord{mIdx, nIdx, 0}; + } else if constexpr (SwizzleDirection == 1) { // Nz + uint32_t tileBlockLoop = CeilDiv(loopsMN.column(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.row()); + 
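+            // Offset of this block within the current band of SwizzleOffset columns (Nz traversal order)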
uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.row()); + + uint32_t nCol = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nCol = loopsMN.column() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = inTileBlockIdx / nCol; + uint32_t nIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; + if (tileBlockIdx % 2 == 1) { + mIdx = loopsMN.row() - mIdx - 1; + } + return GemmCoord{mIdx, nIdx, 0}; + } + } + + CATLASS_DEVICE + GemmCoord GetActualBlockShape(GemmCoord blockCoord) + { + uint32_t mActual = + (blockCoord.m() == (loopsMN.row() - 1)) ? (problemShape.m() - blockCoord.m() * tileMN.row()) : tileMN.row(); + uint32_t nActual = (blockCoord.n() == (loopsMN.column() - 1)) + ? (problemShape.n() - blockCoord.n() * tileMN.column()) + : tileMN.column(); + uint32_t kActual = problemShape.k(); + return GemmCoord{mActual, nActual, kActual}; + } +}; + +} // namespace Catlass::Gemm::Block + +#endif // CATLASS_GEMM_BLOCK_BLOCK_SWIZZLE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/device/device_gemm.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/device/device_gemm.hpp new file mode 100644 index 00000000..7d205dff --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/device/device_gemm.hpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_DEVICE_DEVICE_GEMM_HPP +#define CATLASS_GEMM_DEVICE_DEVICE_GEMM_HPP + +#include +#include "catlass/catlass.hpp" +#include "catlass/status.hpp" +#include "catlass/gemm/device/kernel_adapter.hpp" + +#if defined(ENABLE_ASCENDC_DUMP) +#include "catlass/debug.hpp" +#endif + +namespace Catlass::Gemm::Device { + +template +class DeviceGemm +{ +public: + using Kernel = GemmKernel; + /// Argument structure: User API + using Arguments = typename GemmKernel::Arguments; + /// Argument structure: Kernel API + using Params = typename GemmKernel::Params; + +private: + /// kernel API parameters object + Params params_; + +public: + DeviceGemm() {} + ~DeviceGemm() {} + + /// Access the Params structure + Params const ¶ms() const + { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status CanImplement(Arguments const &args) + { + if (GemmKernel::CanImplement(args)) { + return Status::kSuccess; + } else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t GetWorkspaceSize(Arguments const &args) + { + size_t workspace_bytes = 0; + workspace_bytes += GemmKernel::GetWorkspaceSize(args); + return workspace_bytes; + } + + /// Initializes GEMM state from arguments + Status Initialize(Arguments const &args, uint8_t *workspace = nullptr, aclrtStream stream = nullptr) + { + // Initialize the Params structure + params_ = GemmKernel::ToUnderlyingArguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
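+    /// If fftsAddr is nonzero, Run() forwards it to the kernel adapter, which registers it via SetSyncBaseAddr for cross-core synchronization.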
+ /// Supplied params struct must be construct by calling matmul Kernel::to_underling arguments + inline Status Run(aclrtStream stream, uint32_t blockDim, uint64_t fftsAddr) + { +#if defined(ENABLE_ASCENDC_DUMP) + uint8_t *ptrDump{nullptr}; + aclCheck(aclrtMalloc(reinterpret_cast(&ptrDump), ALL_DUMPSIZE, ACL_MEM_MALLOC_HUGE_FIRST)); + if (fftsAddr == 0) { + Catlass::KernelAdapter<<>>(params_, ptrDump); + } else { + Catlass::KernelAdapter<<>>(params_, fftsAddr, ptrDump); + } + aclCheck(aclrtSynchronizeStream(stream)); + Adx::AdumpPrintWorkSpace(ptrDump, ALL_DUMPSIZE, stream, "device_gemm"); + aclCheck(aclrtFree(ptrDump)); +#else + if (fftsAddr == 0) { + Catlass::KernelAdapter<<>>(params_); + } else { + Catlass::KernelAdapter<<>>(params_, fftsAddr); + } +#endif + return Status::kSuccess; + } + + /// Runs the kernel using initialized state + inline Status operator()(aclrtStream stream, uint32_t blockDim) + { + return Run(stream, blockDim, 0); + } + + inline Status operator()(aclrtStream stream, uint32_t blockDim, uint64_t fftsAddr) + { + return Run(stream, blockDim, fftsAddr); + } +}; +/////////////////////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Device +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/device/kernel_adapter.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/device/kernel_adapter.hpp new file mode 100644 index 00000000..be7a3d56 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/device/kernel_adapter.hpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef CATLASS_GEMM_DEVICE_KERNEL_ADAPTER_HPP +#define CATLASS_GEMM_DEVICE_KERNEL_ADAPTER_HPP + +#include "catlass/catlass.hpp" + +#if defined(ENABLE_ASCENDC_DUMP) +#include "catlass/debug.hpp" +#endif + +namespace Catlass { +/// Generic Catlass kernel template +template +CATLASS_GLOBAL void KernelAdapter(typename Operator::Params params, GM_ADDR ptrDump = nullptr) +{ + Operator op; +#if defined(ENABLE_ASCENDC_DUMP) + AscendC::InitDump(false, ptrDump, ALL_DUMPSIZE); +#endif + op(params); +} + +template +CATLASS_GLOBAL void KernelAdapter(typename Operator::Params params, uint64_t fftsAddr, GM_ADDR ptrDump = nullptr) +{ + AscendC::SetSyncBaseAddr(fftsAddr); + Operator op; +#if defined(ENABLE_ASCENDC_DUMP) + AscendC::InitDump(false, ptrDump, ALL_DUMPSIZE); +#endif + op(params); +} +} // namespace Catlass + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/dispatch_policy.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/dispatch_policy.hpp new file mode 100644 index 00000000..30cc5956 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/dispatch_policy.hpp @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_DISPATCH_POLICY_HPP +#define CATLASS_GEMM_DISPATCH_POLICY_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" + +namespace Catlass::Gemm { + +// Block Mmad Policies + +template +struct MmadAtlasA2Base { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t ASYNC = ASYNC_; +}; + +using MmadAtlasA2 = MmadAtlasA2Base; +using MmadAtlasA2Async = MmadAtlasA2Base; + +// Now ENABLE_UNIT_FLAG_ must be false when input element is int8 +template +struct MmadAtlasA2Pingpong : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; +}; + +template +struct MmadAtlasA2PingpongSliceK : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; +}; + +template +struct MmadAtlasA2Preload : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; +}; + +struct MmadAtlasA2FAQK : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2FAPV : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2MLAQK : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2MLAPV : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2MLAQKTp1Spec : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2MLAPVTp1Spec : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +template +struct MmadAtlasA2PreloadAsync : public MmadAtlasA2Async { + static constexpr uint32_t PRELOAD_STAGES = PRELOAD_STAGES_; // Stages of emitting load instruction in advance + static constexpr uint32_t L1_STAGES = L1_STAGES_; + static constexpr uint32_t L0A_STAGES = L0A_STAGES_; + static constexpr uint32_t L0B_STAGES = L0B_STAGES_; + static constexpr uint32_t L0C_STAGES = L0C_STAGES_; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; +}; + +template +struct MmadAtlasA2PreloadAsyncWithCallback + : public MmadAtlasA2PreloadAsync {}; +//////////////////// +// new add +template +struct GemmAtlasA2 : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; + static constexpr bool ENABLE_ABBA = ENABLE_ABBA_; +}; + +struct GemvAtlasA2 : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; +//////////////////// + +template +struct MmadAtlasA2PingpongBias : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; +}; + +template +struct MmadAtlasA2FAIQK : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool PAGED_CACHE_FLAG = PAGED_CACHE_FLAG_; + static constexpr bool ENABLE_UNIT_FLAG = 
ENABLE_UNIT_FLAG_; +}; + +template +struct MmadAtlasA2FAIPV : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool PAGED_CACHE_FLAG = PAGED_CACHE_FLAG_; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; +}; + +template +struct MmadAtlasA2FAITailQK : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool PAGED_CACHE_FLAG = PAGED_CACHE_FLAG_; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; +}; + +template +struct MmadAtlasA2FAITailPV : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool PAGED_CACHE_FLAG = PAGED_CACHE_FLAG_; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; +}; + +template +struct MmadAtlasA2FullLoadA : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; +}; + +template +struct MmadAtlasA2W8A16 : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; +}; +} // namespace Catlass::Gemm + +#endif // CATLASS_GEMM_DISPATCH_POLICY_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/gemm_type.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/gemm_type.hpp new file mode 100644 index 00000000..97f9a04f --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/gemm_type.hpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_GEMM_TYPE_HPP +#define CATLASS_GEMM_GEMM_TYPE_HPP + +#include "catlass/catlass.hpp" + +namespace Catlass::Gemm { + +//////////////////////////////////////////////////////////////////// + +template +struct GemmType { + using Element = Element_; + using Layout = Layout_; + static constexpr AscendC::TPosition POSITION = POSITION_; +}; + +} // namespace Catlass::Gemm + +#endif // CATLASS_GEMM_GEMM_TYPE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/helper.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/helper.hpp new file mode 100644 index 00000000..56dba6d4 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/helper.hpp @@ -0,0 +1,285 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_HELPER_HPP +#define CATLASS_GEMM_HELPER_HPP + +#include "catlass/catlass.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm/gemm_type.hpp" +#include "tla/layout.hpp" + +namespace Catlass::Gemm::helper { + +template +struct L1AlignHelper { + static_assert(DEPENDENT_FALSE, "Unsupported align helper, can not find the specialization."); +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; +}; + +template +struct ElementAccumulatorSelector { + static_assert(DEPENDENT_FALSE, + "Unsupported element accumulator selector, can not find the specialization."); +}; + +template <> +struct ElementAccumulatorSelector { + using ElementAccumulator = float; +}; + +template <> +struct ElementAccumulatorSelector { + using ElementAccumulator = float; +}; + +template <> +struct ElementAccumulatorSelector { + using ElementAccumulator = int32_t; +}; + +template <> +struct ElementAccumulatorSelector { + using ElementAccumulator = float; +}; + +template +struct L1ATypeSelector { + static_assert(DEPENDENT_FALSE, "Unsupported layout selector, can not find the specialization."); +}; + +template +struct L1ATypeSelector> { + using L1AType = Gemm::GemmType; +}; + +template +struct L1ATypeSelector> { + using L1AType = Gemm::GemmType; +}; + +template +struct L1ATypeSelector> { + using L1AType = Gemm::GemmType; +}; + +template +struct L1ATypeSelector> { + using L1AType = Gemm::GemmType; +}; + +template +struct L1ATypeSelector> { + using L1AType = Gemm::GemmType; +}; + +template +struct L1BTypeSelector { + static_assert(DEPENDENT_FALSE, "Unsupported layout selector, can not find the specialization."); +}; + +template +struct L1BTypeSelector> { + using L1BType = Gemm::GemmType; +}; + +template +struct L1BTypeSelector> { + using L1BType = Gemm::GemmType; +}; + +template +struct L1BTypeSelector> { + using L1BType = Gemm::GemmType; +}; + 
+template +struct L1BTypeSelector> { + using L1BType = Gemm::GemmType; +}; + +template +struct L1BTypeSelector> { + using L1BType = Gemm::GemmType; +}; + +template +struct L1BTypeSelector> { + using L1BType = Gemm::GemmType; +}; + +template +struct L1BTypeSelector> { + using L1BType = Gemm::GemmType; +}; + +template +struct L1BiasTypeSelector { + static_assert(DEPENDENT_FALSE, "Unsupported layout selector, can not find the specialization."); +}; + +template +struct L1BiasTypeSelector { + using GMBiasType = void; + using L1BiasType = void; + using L0BiasType = void; +}; + +template +struct L1BiasTypeSelector, ElementAccumulator> { + using GMBiasType = Gemm::GemmType; + using L1BiasType = Gemm::GemmType; + using L0BiasType = Gemm::GemmType; +}; + +template +struct L1AlignHelperTla { + static_assert(DEPENDENT_FALSE, "Unsupported align helper tla, can not find the specialization."); +}; + +template +struct L1AlignHelperTla::value>> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; +}; + +template +struct L1AlignHelperTla::value>> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; +}; + +/////////////////////////////////////// +// new add +template <> +struct ElementAccumulatorSelector { + using ElementAccumulator = int32_t; +}; + +template +struct L1AndL0TypeSelectorGemm { + static_assert(DEPENDENT_FALSE, "Unsupported layout selector, can not find the specialization."); + static_assert(DEPENDENT_FALSE, "Unsupported layout selector, can not find the specialization."); +}; + +template +struct L1AndL0TypeSelectorGemm, Gemm::GemmType> { + using L1AType = Gemm::GemmType; + using L1BType = Gemm::GemmType; + using L0AType = Gemm::GemmType; + using L0BType = Gemm::GemmType; +}; + +template <> +struct L1AndL0TypeSelectorGemm, Gemm::GemmType> { + using L1AType = Gemm::GemmType; + using L1BType = Gemm::GemmType; + using L0AType = Gemm::GemmType; + using L0BType = Gemm::GemmType; +}; + +template +struct L1AndL0TypeSelectorGemm, + Gemm::GemmType> { + using L1AType = Gemm::GemmType; + using L1BType = Gemm::GemmType; + using L0AType = Gemm::GemmType; + using L0BType = Gemm::GemmType; +}; + +template <> +struct L1AndL0TypeSelectorGemm, + Gemm::GemmType> { + using L1AType = Gemm::GemmType; + using L1BType = Gemm::GemmType; + using L0AType = Gemm::GemmType; + using L0BType = Gemm::GemmType; +}; + +template +struct L1AndL0TypeSelectorGemm, + Gemm::GemmType> { + using L1AType = Gemm::GemmType; + using L1BType = Gemm::GemmType; + using L0AType = Gemm::GemmType; + using L0BType = Gemm::GemmType; +}; + +template +struct L1AndL0TypeSelectorGemm, + Gemm::GemmType> { + using L1AType = Gemm::GemmType; + using L1BType = Gemm::GemmType; + using L0AType = Gemm::GemmType; + using L0BType = Gemm::GemmType; +}; + +template <> +struct L1AndL0TypeSelectorGemm, Gemm::GemmType> { + using L1AType = Gemm::GemmType; + using L1BType = Gemm::GemmType; + using L0AType = Gemm::GemmType; + using L0BType = Gemm::GemmType; +}; +/////////////////////////////////////// +} // namespace Catlass::Gemm::helper + +#endif // CATLASS_GEMM_HELPER_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/basic_matmul.hpp 
b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/basic_matmul.hpp new file mode 100644 index 00000000..4a02d8f5 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/basic_matmul.hpp @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_MATMUL_HPP +#define CATLASS_GEMM_KERNEL_MATMUL_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +// Template for Matmul kernel. Compute C = A * B +template +class BasicMatmul +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockScheduler = BlockScheduler_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrC; + LayoutC layoutC; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrC(ptrC_), + layoutC(layoutC_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR ptrC; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + LayoutA layoutA{args.problemShape.m(), args.problemShape.k()}; + LayoutB layoutB{args.problemShape.k(), args.problemShape.n()}; + LayoutC layoutC{args.problemShape.m(), args.problemShape.n()}; + Params params{args.problemShape, args.ptrA, layoutA, args.ptrB, layoutB, args.ptrC, layoutC}; + return params; + } + + // Methods + CATLASS_DEVICE + BasicMatmul() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + /// Executes one Matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + Arch::Resource resource; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA 
*)params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + int64_t gmOffsetB = params.layoutB.GetOffset(offsetB); + int64_t gmOffsetC = params.layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC], params.layoutC, + actualBlockShape); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + {} +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_MATMUL_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/basic_matmul_preload.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/basic_matmul_preload.hpp new file mode 100644 index 00000000..b4cf4a3c --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/basic_matmul_preload.hpp @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_MATMUL_PRELOAD_HPP +#define CATLASS_GEMM_KERNEL_MATMUL_PRELOAD_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +// Template for Matmul kernel. 
Compute C = A * B +template +class BasicMatmulPreload +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockScheduler = BlockScheduler_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrC; + LayoutC layoutC; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrC(ptrC_), + layoutC(layoutC_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR ptrC; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + LayoutA layoutA{args.problemShape.m(), args.problemShape.k()}; + // LayoutB layoutB{args.problemShape.k(), args.problemShape.n()}; // for RowMajor & ColMajor layout + LayoutB layoutB = + layout::zN::MakeLayout(args.problemShape.k(), args.problemShape.n()); // for zN layout + LayoutC layoutC{args.problemShape.m(), args.problemShape.n()}; + Params params{args.problemShape, args.ptrA, layoutA, args.ptrB, layoutB, args.ptrC, layoutC}; + return params; + } + + // Methods + CATLASS_DEVICE + BasicMatmulPreload() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + /// Executes one Matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + Arch::Resource resource; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + int64_t gmOffsetB = params.layoutB.GetOffset(offsetB); + int64_t gmOffsetC = params.layoutC.GetOffset(offsetC); + + bool isFirstBlock = (loopIdx == AscendC::GetBlockIdx()); + bool 
hasNextBlock = false; + GemmCoord nextBlockIdCoord; + GemmCoord nextActualBlockShape; + if (loopIdx + AscendC::GetBlockNum() < coreLoops) { + hasNextBlock = true; + nextBlockIdCoord = matmulBlockScheduler.GetBlockCoord(loopIdx + AscendC::GetBlockNum()); + nextActualBlockShape = matmulBlockScheduler.GetActualBlockShape(nextBlockIdCoord); + } + MatrixCoord offsetNextA{nextBlockIdCoord.m() * L1TileShape::M, nextBlockIdCoord.k() * L1TileShape::K}; + MatrixCoord offsetNextB{nextBlockIdCoord.k() * L1TileShape::K, nextBlockIdCoord.n() * L1TileShape::N}; + int64_t gmOffsetNextA = params.layoutA.GetOffset(offsetNextA); + int64_t gmOffsetNextB = params.layoutB.GetOffset(offsetNextB); + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC], params.layoutC, + gmA[gmOffsetNextA], gmB[gmOffsetNextB], actualBlockShape, nextActualBlockShape, isFirstBlock, + hasNextBlock); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + {} +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_MATMUL_PRELOAD_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/basic_matmul_tla.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/basic_matmul_tla.hpp new file mode 100644 index 00000000..dfe8e59e --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/basic_matmul_tla.hpp @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_MATMUL_TLA_HPP +#define CATLASS_GEMM_KERNEL_MATMUL_TLA_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "tla/tensor.hpp" +#include "tla/layout.hpp" +#include "tla/tensor.hpp" + +namespace Catlass::Gemm::Kernel { + +// Template for Matmul kernel. 
Compute C = A * B +template +class BasicMatmulTla +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockScheduler = BlockScheduler_; + + static constexpr uint32_t L1_TILE_M = tla::get<0>(L1TileShape{}); + static constexpr uint32_t L1_TILE_N = tla::get<1>(L1TileShape{}); + static constexpr uint32_t L1_TILE_K = tla::get<2>(L1TileShape{}); + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrC; + LayoutC layoutC; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrC(ptrC_), + layoutC(layoutC_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + uint8_t *ptrA; + LayoutA layoutA; + uint8_t *ptrB; + LayoutB layoutB; + uint8_t *ptrC; + LayoutC layoutC; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + uint32_t m = args.problemShape.m(); + uint32_t n = args.problemShape.n(); + uint32_t k = args.problemShape.k(); + Params params{args.problemShape, args.ptrA, args.layoutA, args.ptrB, args.layoutB, args.ptrC, args.layoutC}; + return params; + } + + // Methods + CATLASS_DEVICE + BasicMatmulTla() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + /// Executes one Matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1_TILE_M, L1_TILE_N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + Arch::Resource resource; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); + + // Represent the full tensors + auto tensorA = tla::MakeTensor(gmA, params.layoutA, Arch::PositionGM{}); + auto tensorB = tla::MakeTensor(gmB, params.layoutB, Arch::PositionGM{}); + auto tensorC = tla::MakeTensor(gmC, params.layoutC, Arch::PositionGM{}); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); + + // Make tiled views + auto tensorBlockA = GetTile(tensorA, tla::MakeCoord(blockCoord.m() * L1_TILE_M, blockCoord.k() * L1_TILE_K), + tla::MakeShape(actualBlockShape.m(), actualBlockShape.k())); + auto tensorBlockB = GetTile(tensorB, tla::MakeCoord(blockCoord.k() * 
L1_TILE_K, blockCoord.n() * L1_TILE_N), + tla::MakeShape(actualBlockShape.k(), actualBlockShape.n())); + auto tensorBlockC = GetTile(tensorC, tla::MakeCoord(blockCoord.m() * L1_TILE_M, blockCoord.n() * L1_TILE_N), + tla::MakeShape(actualBlockShape.m(), actualBlockShape.n())); + + // Compute block-scoped matrix multiply-add + blockMmad(tensorBlockA, tensorBlockB, tensorBlockC, actualBlockShape); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + {} +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_MATMUL_TLA_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/batched_matmul.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/batched_matmul.hpp new file mode 100644 index 00000000..8fbda626 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/batched_matmul.hpp @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_BATCHED_MATMUL_HPP +#define CATLASS_GEMM_KERNEL_BATCHED_MATMUL_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +// Template for Batched Matmul kernel. 
Compute batched C = A * B +template +class BatchedMatmul +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockScheduler = BlockScheduler_; + + /// Parameters structure + struct Params { + // Data members + uint32_t batchCount; + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + int64_t strideA; + GM_ADDR ptrB; + LayoutB layoutB; + int64_t strideB; + GM_ADDR ptrC; + LayoutC layoutC; + int64_t strideC; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(uint32_t batchCount_, GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, int64_t strideA_, + GM_ADDR ptrB_, LayoutB layoutB_, int64_t strideB_, GM_ADDR ptrC_, LayoutC layoutC_, int64_t strideC_) + : batchCount(batchCount_), + problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + strideA(strideA_), + ptrB(ptrB_), + layoutB(layoutB_), + strideB(strideB_), + ptrC(ptrC_), + layoutC(layoutC_), + strideC(strideC_) + {} + }; + + struct Arguments { + uint32_t batchCount; + GemmCoord problemShape; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR ptrC; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + GemmCoord problemShape = args.problemShape; + uint32_t m = problemShape.m(); + uint32_t n = problemShape.n(); + uint32_t k = problemShape.k(); + int64_t strideA = problemShape.m() * problemShape.k(); + int64_t strideB = problemShape.k() * problemShape.n(); + int64_t strideC = problemShape.m() * problemShape.n(); + LayoutA layoutA{args.problemShape.m(), args.problemShape.k()}; + LayoutB layoutB{args.problemShape.k(), args.problemShape.n()}; + LayoutC layoutC{args.problemShape.m(), args.problemShape.n()}; + Params params{args.batchCount, problemShape, args.ptrA, layoutA, strideA, args.ptrB, + layoutB, strideB, args.ptrC, layoutC, strideC}; + return params; + } + + // Methods + CATLASS_DEVICE + BatchedMatmul() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + /// Executes one GEMM + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = params.batchCount * matmulBlockScheduler.GetCoreLoops(); + + Arch::Resource resource; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute block location + uint32_t batchIdx = matmulBlockScheduler.GetBatchIdx(loopIdx); + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); + + // 
batchOffset + int64_t batchOffsetA = batchIdx * params.strideA; + int64_t batchOffsetB = batchIdx * params.strideB; + int64_t batchOffsetC = batchIdx * params.strideC; + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + int64_t gmOffsetB = params.layoutB.GetOffset(offsetB); + int64_t gmOffsetC = params.layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[batchOffsetA + gmOffsetA], params.layoutA, gmB[batchOffsetB + gmOffsetB], params.layoutB, + gmC[batchOffsetC + gmOffsetC], params.layoutC, actualBlockShape); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + {} +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_BATCHED_MATMUL_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/fp8_matmul.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/fp8_matmul.hpp new file mode 100644 index 00000000..c4ef8a92 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/fp8_matmul.hpp @@ -0,0 +1,528 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_KERNEL_FP8_MATMUL_HPP +#define CATLASS_GEMM_KERNEL_FP8_MATMUL_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/gemm/block/block_dequant.hpp" + +namespace Catlass::Gemm::Kernel { + +template +class FP8Matmul +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockScheduler = BlockScheduler_; + + static const uint32_t COMPUTE_LENGTH_A = 16 * 1024 / sizeof(int8_t); + using PrologueA = Block::DequantFP8toFP16; + static const uint32_t COMPUTE_LENGTH_B = 16 * 1024 / sizeof(int8_t); + using PrologueB = Block::DequantFP8toFP16; + + using Cast = Block::DequantFP8toFP16; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrC; + LayoutC layoutC; + GM_ADDR ptrWA; + GM_ADDR ptrWB; + GM_ADDR ptrWC; + half scalar; + half zeroPoint; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_, GM_ADDR ptrWA_, GM_ADDR ptrWB_, GM_ADDR ptrWC_, half scalar_, + half zeroPoint_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrC(ptrC_), + layoutC(layoutC_), + ptrWA(ptrWA_), + ptrWB(ptrWB_), + ptrWC(ptrWC_), + scalar(scalar_), + zeroPoint(zeroPoint_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR ptrC; + GM_ADDR ptrWA; + GM_ADDR ptrWB; + GM_ADDR ptrWC; + half scalar; + half zeroPoint; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + LayoutA layoutA{args.problemShape.m(), args.problemShape.k()}; + LayoutB layoutB{args.problemShape.k(), args.problemShape.n()}; + LayoutC layoutC{args.problemShape.m(), args.problemShape.n()}; + Params params{args.problemShape, args.ptrA, layoutA, args.ptrB, layoutB, args.ptrC, + layoutC, args.ptrWA, args.ptrWB, args.ptrWC, args.scalar, args.zeroPoint}; + return params; + } + + // Methods + CATLASS_DEVICE + FP8Matmul() + { + flag0[0].id = flagID0; + flag0[1].id = flagID1; + flag1[0].id = flagID2; + flag1[1].id = flagID3; + } + + /// Executes one GEMM + template + CATLASS_DEVICE __attribute__((always_inline)) void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler(params.problemShape, + MakeCoord((L1TileShape::M * mScalar), (L1TileShape::N * nScalar))); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ int8_t *)params.ptrA); + AscendC::GlobalTensor gmWA; + gmWA.SetGlobalBuffer((__gm__ half 
*)params.ptrWA); + + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ int8_t *)params.ptrB); + AscendC::GlobalTensor gmWB; + gmWB.SetGlobalBuffer((__gm__ half *)params.ptrWB); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ half *)params.ptrC); + AscendC::GlobalTensor gmWC; + gmWC.SetGlobalBuffer((__gm__ float *)params.ptrWC); + + uint32_t srcAStride = params.problemShape.k(); + uint32_t srcBStride = params.problemShape.n(); + + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + AscendC::SetFlag(EVENT_ID2); + AscendC::SetFlag(EVENT_ID3); + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + AscendC::SetFlag(EVENT_ID2); + AscendC::SetFlag(EVENT_ID3); + for (uint32_t loopIdx = AscendC::GetBlockIdx() / AIVPERCORE; loopIdx < coreLoops; + loopIdx += AscendC::GetBlockNum()) { // Each loop iteration dequantizes two row blocks or two column blocks + // Current task block info + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + MatrixCoord offsetA{blockCoord.m() * (L1TileShape::M * mScalar), 0}; + MatrixCoord offsetB{0, blockCoord.n() * (L1TileShape::N * nScalar)}; + MatrixCoord offsetC{blockCoord.m() * (L1TileShape::M * mScalar), + blockCoord.n() * (L1TileShape::N * nScalar)}; + int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + int64_t gmOffsetB = params.layoutB.GetOffset(offsetB); + int64_t gmOffsetC = params.layoutC.GetOffset(offsetC); + + // Next task block info + bool isFirstBlock = (loopIdx == (AscendC::GetBlockIdx() / AIVPERCORE)); + bool hasNextBlock = false; + GemmCoord nextBlockCoord; + GemmCoord nextActualBlockShape; + if (loopIdx + AscendC::GetBlockNum() < coreLoops) { + hasNextBlock = true; + nextBlockCoord = blockScheduler.GetBlockCoord(loopIdx + AscendC::GetBlockNum()); + nextActualBlockShape = blockScheduler.GetActualBlockShape(nextBlockCoord); + } + MatrixCoord offsetNextA{nextBlockCoord.m() * (L1TileShape::M * mScalar), 0}; + MatrixCoord offsetNextB{0, nextBlockCoord.n() * (L1TileShape::N * nScalar)}; + MatrixCoord offsetNextC{nextBlockCoord.m() * (L1TileShape::M * mScalar), + nextBlockCoord.n() * (L1TileShape::N * nScalar)}; + int64_t gmOffsetNextA = params.layoutA.GetOffset(offsetNextA); + int64_t gmOffsetNextB = params.layoutB.GetOffset(offsetNextB); + int64_t gmOffsetNextC = params.layoutC.GetOffset(offsetNextC); + + Arch::Resource resource; + uint32_t kLoop = (params.problemShape.k() + splitkLength - 1) / splitkLength; + for (uint32_t ldk = 0; ldk < kLoop; ldk++) { // Each loop iteration handles one row/column block of the split-K dimension + + // Workspace offsets for the dequantized data + int64_t gmOffsetWA = + (AscendC::GetBlockIdx() / AIVPERCORE) * (mScalar * L1TileShape::M) * splitkLength * STAGES + + (mScalar * L1TileShape::M) * splitkLength * crossCoreBufferIndexAIV; + int64_t gmOffsetWB = + (AscendC::GetBlockIdx() / AIVPERCORE) * splitkLength * (nScalar * L1TileShape::N) * STAGES + + splitkLength * (nScalar * L1TileShape::N) * crossCoreBufferIndexAIV; + int64_t gmOffsetNextWA = + (AscendC::GetBlockIdx() / AIVPERCORE) * (mScalar * L1TileShape::M) * splitkLength * STAGES + + (mScalar * L1TileShape::M) * splitkLength * (1 - crossCoreBufferIndexAIV); + int64_t gmOffsetNextWB = + (AscendC::GetBlockIdx() / AIVPERCORE) * splitkLength * (nScalar * L1TileShape::N) * STAGES + + splitkLength * (nScalar * L1TileShape::N) * (1 - crossCoreBufferIndexAIV); + + uint32_t kActual = (params.problemShape.k() < (ldk + 1) * splitkLength) + ?
params.problemShape.k() % splitkLength + : splitkLength; + uint32_t kActualAligned = (kActual + 256 - 1) / 256 * 256; + + LayoutA layoutWA(actualBlockShape.m(), kActual, kActualAligned); + LayoutB layoutWB(kActual, actualBlockShape.n(), actualBlockShape.n()); + + if (ldk == 0 && isFirstBlock) { // First K slice of the first task block + Catlass::Arch::CrossCoreWaitFlag(flag0[crossCoreBufferIndexAIV]); + if (std::is_same_v<LayoutA, layout::RowMajor>) { // A is row-major + PrologueA prologueA(resource); + prologueA(gmA[gmOffsetA], gmWA[gmOffsetWA], layoutWA, srcAStride, kActualAligned, params.scalar, + params.zeroPoint, bufferIndex); + } else { // A is column-major + srcAStride = params.problemShape.m(); + PrologueA prologueA(resource); + prologueA(gmA[gmOffsetA], gmWA[gmOffsetWA], layoutWA, srcAStride, actualBlockShape.m(), + params.scalar, params.zeroPoint, bufferIndex); + } + if (std::is_same_v<LayoutB, layout::RowMajor>) { // B is row-major + PrologueB prologueB(resource); + prologueB(gmB[gmOffsetB], gmWB[gmOffsetWB], layoutWB, srcBStride, actualBlockShape.n(), + params.scalar, params.zeroPoint, bufferIndex); + } else { // B is column-major + srcBStride = params.problemShape.k(); + PrologueB prologueB(resource); + prologueB(gmB[gmOffsetB], gmWB[gmOffsetWB], layoutWB, srcBStride, kActualAligned, params.scalar, + params.zeroPoint, bufferIndex); + } + } + if (ldk < kLoop - 1) { // Subsequent K slices + uint32_t kActualNext = (params.problemShape.k() < (ldk + 2) * splitkLength) + ? params.problemShape.k() % splitkLength + : splitkLength; + uint32_t kActualNextAligned = (kActualNext + 256 - 1) / 256 * 256; + + LayoutA layoutNextWA(actualBlockShape.m(), kActualNext, kActualNextAligned); + LayoutB layoutNextWB(kActualNext, actualBlockShape.n(), actualBlockShape.n()); + + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flag1[crossCoreBufferIndexAIV]); + Catlass::Arch::CrossCoreWaitFlag(flag0[1 - crossCoreBufferIndexAIV]); + if (std::is_same_v<LayoutA, layout::RowMajor>) { // A is row-major + PrologueA prologueA(resource); + gmOffsetA += kActual; + prologueA(gmA[gmOffsetA], gmWA[gmOffsetNextWA], layoutNextWA, srcAStride, kActualNextAligned, + params.scalar, params.zeroPoint, bufferIndex); + } else { // A is column-major + srcAStride = params.problemShape.m(); + PrologueA prologueA(resource); + gmOffsetA += kActual * params.problemShape.m(); + prologueA(gmA[gmOffsetA], gmWA[gmOffsetNextWA], layoutNextWA, srcAStride, actualBlockShape.m(), + params.scalar, params.zeroPoint, bufferIndex); + } + if (std::is_same_v<LayoutB, layout::RowMajor>) { // B is row-major + PrologueB prologueB(resource); + gmOffsetB += kActual * params.problemShape.n(); + prologueB(gmB[gmOffsetB], gmWB[gmOffsetNextWB], layoutNextWB, srcBStride, actualBlockShape.n(), + params.scalar, params.zeroPoint, bufferIndex); + } else { // B is column-major + srcBStride = params.problemShape.k(); + PrologueB prologueB(resource); + gmOffsetB += kActual; + prologueB(gmB[gmOffsetB], gmWB[gmOffsetNextWB], layoutNextWB, srcBStride, kActualNextAligned, + params.scalar, params.zeroPoint, bufferIndex); + } + } + if ((ldk == kLoop - 1) && hasNextBlock) { // Last K slice of the current block, and there is a next task block + uint32_t kActualNext = (params.problemShape.k() < splitkLength) + ?
params.problemShape.k() % splitkLength + : splitkLength; + uint32_t kActualNextAligned = (kActualNext + 256 - 1) / 256 * 256; + + LayoutA layoutNextWA(nextActualBlockShape.m(), kActualNext, kActualNext); + LayoutB layoutNextWB(kActualNext, nextActualBlockShape.n(), nextActualBlockShape.n()); + + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flag1[crossCoreBufferIndexAIV]); + Catlass::Arch::CrossCoreWaitFlag(flag0[1 - crossCoreBufferIndexAIV]); + if (std::is_same_v<LayoutA, layout::RowMajor>) { // A is row-major + PrologueA prologueA(resource); + prologueA(gmA[gmOffsetNextA], gmWA[gmOffsetNextWA], layoutNextWA, srcAStride, + kActualNextAligned, params.scalar, params.zeroPoint, bufferIndex); + } else { // A is column-major + srcAStride = params.problemShape.m(); + PrologueA prologueA(resource); + prologueA(gmA[gmOffsetNextA], gmWA[gmOffsetNextWA], layoutNextWA, srcAStride, + nextActualBlockShape.m(), params.scalar, params.zeroPoint, bufferIndex); + } + if (std::is_same_v<LayoutB, layout::RowMajor>) { // B is row-major + PrologueB prologueB(resource); + prologueB(gmB[gmOffsetNextB], gmWB[gmOffsetNextWB], layoutNextWB, srcBStride, + nextActualBlockShape.n(), params.scalar, params.zeroPoint, bufferIndex); + } else { // B is column-major + srcBStride = params.problemShape.k(); + PrologueB prologueB(resource); + prologueB(gmB[gmOffsetNextB], gmWB[gmOffsetNextWB], layoutNextWB, srcBStride, + kActualNextAligned, params.scalar, params.zeroPoint, bufferIndex); + } + + Catlass::Arch::CrossCoreWaitFlag(flag4); + Catlass::layout::RowMajor layoutBlockC(actualBlockShape.m(), actualBlockShape.n(), + params.problemShape.n()); + int64_t gmOffsetWC = + (AscendC::GetBlockIdx() / AIVPERCORE) * (mScalar * L1TileShape::M) * (nScalar * L1TileShape::N); + Cast cast; + cast.castFP32toFP16(gmWC[gmOffsetWC], gmC[gmOffsetC], layoutBlockC, nScalar * L1TileShape::N, + params.problemShape.n()); + } + if ((ldk == kLoop - 1) && (!hasNextBlock)) { // Last K slice of the current block, and there is no next task block + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flag1[crossCoreBufferIndexAIV]); + Catlass::Arch::CrossCoreWaitFlag(flag4); + Catlass::layout::RowMajor layoutBlockC(actualBlockShape.m(), actualBlockShape.n(), + params.problemShape.n()); + int64_t gmOffsetWC = + (AscendC::GetBlockIdx() / AIVPERCORE) * (mScalar * L1TileShape::M) * (nScalar * L1TileShape::N); + Cast cast; + cast.castFP32toFP16(gmWC[gmOffsetWC], gmC[gmOffsetC], layoutBlockC, nScalar * L1TileShape::N, + params.problemShape.n()); + } + crossCoreBufferIndexAIV = 1 - crossCoreBufferIndexAIV; + } + } + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID3); + + Catlass::Arch::CrossCoreWaitFlag(flag0[0]); + Catlass::Arch::CrossCoreWaitFlag(flag0[1]); + } + + template <> + CATLASS_DEVICE void operator()(Params const &params) + { + BlockScheduler blockScheduler(params.problemShape, + MakeCoord((L1TileShape::M * mScalar), (L1TileShape::N * nScalar))); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + Arch::Resource resource; + BlockMmad blockMmad(resource); + + AscendC::GlobalTensor gmWA; + gmWA.SetGlobalBuffer((__gm__ half *)params.ptrWA); + AscendC::GlobalTensor gmWB; + gmWB.SetGlobalBuffer((__gm__ half *)params.ptrWB); + AscendC::GlobalTensor gmWC; + gmWC.SetGlobalBuffer((__gm__ float *)params.ptrWC); + + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(flag0[0]); + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(flag0[1]); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; +
loopIdx += AscendC::GetBlockNum()) { // Each loop iteration completes one large base result block (256, 512) + // Compute block location + // Get the top-left coordinate and actual size of the current large base result block + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBigBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + uint32_t kLoop = (params.problemShape.k() + splitkLength - 1) / splitkLength; + for (uint32_t ldk = 0; ldk < kLoop; ldk++) { // Each loop iteration handles one row/column block of the split-K dimension + bool isFirstKSlice = (ldk == 0) ? true : false; + uint32_t kActual = (params.problemShape.k() < (ldk + 1) * splitkLength) + ? params.problemShape.k() % splitkLength + : splitkLength; + uint32_t kActualAligned = (kActual + 256 - 1) / 256 * 256; + + uint32_t mLoop = (actualBigBlockShape.m() + L1TileShape::M - 1) / L1TileShape::M; + uint32_t nLoop = (actualBigBlockShape.n() + L1TileShape::N - 1) / L1TileShape::N; + Catlass::Arch::CrossCoreWaitFlag(flag1[crossCoreBufferIndexAIC]); + for (uint32_t processIdm = 0; processIdm < mLoop; processIdm++) { + for (uint32_t processIdn = 0; processIdn < nLoop; processIdn++) { + bool hasNextBlock = ((processIdm * nLoop + processIdn) < mLoop * nLoop - 1) ? true : false; + bool isFirstBlock = (processIdm == 0 && processIdn == 0) ? true : false; + uint32_t processIdxNext = processIdm * nLoop + processIdn + 1; + uint32_t processIdmNext = processIdxNext / nLoop; + uint32_t processIdnNext = processIdxNext % nLoop; + // Compute initial location in logical coordinates + MatrixCoord offsetBlockC{ + blockCoord.m() * (L1TileShape::M * mScalar) + L1TileShape::M * processIdm, + blockCoord.n() * (L1TileShape::N * nScalar) + L1TileShape::N * processIdn}; + + uint32_t mActual = L1TileShape::M; + uint32_t nActual = L1TileShape::N; + if (actualBigBlockShape.m() % L1TileShape::M != 0 && processIdm == mLoop - 1) { + mActual = actualBigBlockShape.m() % L1TileShape::M; + } + if (actualBigBlockShape.n() % L1TileShape::N != 0 && processIdn == nLoop - 1) { + nActual = actualBigBlockShape.n() % L1TileShape::N; + } + GemmCoord actualSmallBlockShape(mActual, nActual, kActual); + + uint32_t mActualNext = L1TileShape::M; + uint32_t nActualNext = L1TileShape::N; + if (actualBigBlockShape.m() % L1TileShape::M != 0 && processIdmNext == mLoop - 1) { + mActualNext = actualBigBlockShape.m() % L1TileShape::M; + } + if (actualBigBlockShape.n() % L1TileShape::N != 0 && processIdnNext == nLoop - 1) { + nActualNext = actualBigBlockShape.n() % L1TileShape::N; + } + GemmCoord nextSmallBlockShape(mActualNext, nActualNext, kActual); + + // Address offsets of the current block + int64_t gmOffsetWA = + AscendC::GetBlockIdx() * (L1TileShape::M * mScalar) * splitkLength * STAGES + + (L1TileShape::M * mScalar) * splitkLength * crossCoreBufferIndexAIC + + processIdm * L1TileShape::M * kActualAligned; + int64_t gmOffsetWB = + AscendC::GetBlockIdx() * splitkLength * (L1TileShape::N * nScalar) * STAGES + + splitkLength * (L1TileShape::N * nScalar) * crossCoreBufferIndexAIC + + processIdn * L1TileShape::N; + int64_t gmOffsetWC = + AscendC::GetBlockIdx() * (L1TileShape::M * mScalar) * (L1TileShape::N * nScalar) + + processIdm * L1TileShape::M * (L1TileShape::N * nScalar) + processIdn * L1TileShape::N; + + uint32_t AStride = kActualAligned; + uint32_t BStride = actualBigBlockShape.n(); + if (std::is_same_v<LayoutA, layout::ColumnMajor>) { // A is column-major + gmOffsetWA = AscendC::GetBlockIdx() * (L1TileShape::M * mScalar) * splitkLength * STAGES + + (L1TileShape::M * mScalar) * splitkLength * crossCoreBufferIndexAIC + + processIdm * L1TileShape::M; + AStride = actualBigBlockShape.m(); + } + if (std::is_same_v<LayoutB, layout::ColumnMajor>) { // B is column-major + gmOffsetWB = AscendC::GetBlockIdx() *
splitkLength * (L1TileShape::N * nScalar) * STAGES + + splitkLength * (L1TileShape::N * nScalar) * crossCoreBufferIndexAIC + + processIdn * L1TileShape::N * kActualAligned; + BStride = kActualAligned; + } + + // Address offsets of the next block + int64_t gmOffsetWANext = + AscendC::GetBlockIdx() * (L1TileShape::M * mScalar) * splitkLength * STAGES + + (L1TileShape::M * mScalar) * splitkLength * crossCoreBufferIndexAIC + + processIdmNext * L1TileShape::M * kActualAligned; + int64_t gmOffsetWBNext = + AscendC::GetBlockIdx() * splitkLength * (L1TileShape::N * nScalar) * STAGES + + splitkLength * (L1TileShape::N * nScalar) * crossCoreBufferIndexAIC + + processIdnNext * L1TileShape::N; + int64_t gmOffsetWCNext = + AscendC::GetBlockIdx() * (L1TileShape::M * mScalar) * (L1TileShape::N * nScalar) + + processIdmNext * L1TileShape::M * (L1TileShape::N * nScalar) + + processIdnNext * L1TileShape::N; + + if (std::is_same_v<LayoutA, layout::ColumnMajor>) { // A is column-major + gmOffsetWANext = + AscendC::GetBlockIdx() * (L1TileShape::M * mScalar) * splitkLength * STAGES + + (L1TileShape::M * mScalar) * splitkLength * crossCoreBufferIndexAIC + + processIdmNext * L1TileShape::M; + AStride = actualBigBlockShape.m(); + } + if (std::is_same_v<LayoutB, layout::ColumnMajor>) { // B is column-major + gmOffsetWBNext = + AscendC::GetBlockIdx() * splitkLength * (L1TileShape::N * nScalar) * STAGES + + splitkLength * (L1TileShape::N * nScalar) * crossCoreBufferIndexAIC + + processIdnNext * L1TileShape::N * kActualAligned; + BStride = kActualAligned; + } + + LayoutA layoutWA(mActual, kActual, AStride); + LayoutB layoutWB(kActual, nActual, BStride); + LayoutC layoutWC(mActual, nActual, nScalar * L1TileShape::N); + + LayoutA layoutWANext(mActualNext, kActual, AStride); + LayoutB layoutWBNext(kActual, nActualNext, BStride); + LayoutC layoutWCNext(mActualNext, nActualNext, nScalar * L1TileShape::N); + + // Compute one 128 * 256 small base result block + blockMmad(gmWA[gmOffsetWA], layoutWA, gmWB[gmOffsetWB], layoutWB, gmWC[gmOffsetWC], layoutWC, + gmWA[gmOffsetWANext], layoutWANext, gmWB[gmOffsetWBNext], layoutWBNext, + actualSmallBlockShape, nextSmallBlockShape, isFirstKSlice, isFirstBlock, + hasNextBlock); + } + } + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(flag0[crossCoreBufferIndexAIC]); + crossCoreBufferIndexAIC = 1 - crossCoreBufferIndexAIC; + if (ldk == kLoop - 1) { + // Cast the 256 * 512 fp32 large result block to fp16 + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(flag4); + } + } + } + } + +protected: + static constexpr uint32_t STAGES = 2; + static constexpr uint32_t AIVPERCORE = 2; + int32_t l1AEventList[STAGES]; + int32_t l1BEventList[STAGES]; + int32_t l0AEventList[STAGES]; + int32_t l0BEventList[STAGES]; + + static constexpr Arch::FlagID flagID0 = 0; + static constexpr Arch::FlagID flagID1 = 1; + static constexpr Arch::FlagID flagID2 = 2; + static constexpr Arch::FlagID flagID3 = 3; + static constexpr Arch::FlagID flagID4 = 4; + + Arch::CrossCoreFlag flag0[STAGES]; + Arch::CrossCoreFlag flag1[STAGES]; + Arch::CrossCoreFlag flag4{flagID4}; + + uint32_t crossCoreBufferIndexAIC{0}; + uint32_t crossCoreBufferIndexAIV{0}; + uint32_t bufferIndex{0}; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_FP8_MATMUL_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/gemm.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/gemm.hpp new file mode 100644 index 00000000..29f2aae2 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/gemm.hpp @@ -0,0 +1,433 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_GEMM_HPP +#define CATLASS_GEMM_KERNEL_GEMM_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/epilogue/tile/copy_gm_to_ub.hpp" +#include "catlass/epilogue/tile/copy_ub_to_gm.hpp" +#include "catlass/gemm/helper.hpp" + +namespace Catlass::Gemm::Kernel { + +template +struct PaddingMatrixND { +public: + using ArchTag = ArchTag_; + using Element = Element_; + using Layout = Layout_; + using CopyGm2Ub = Catlass::Epilogue::Tile::CopyGm2Ub>; + using CopyUb2Gm = Catlass::Epilogue::Tile::CopyUb2Gm>; + using ComputeLayout = Catlass::layout::RowMajor; + + CopyGm2Ub copyGm2Ub; + CopyUb2Gm copyUb2Gm; + + CATLASS_DEVICE + PaddingMatrixND(Arch::Resource &resource) + { + int64_t bufferOffset = 0; + for (uint32_t i = 0; i < BUFFER_NUM; i++) { // + inputBuffer[i] = resource.ubBuf.template GetBufferByByte(bufferOffset * sizeof(Element)); + bufferOffset += COMPUTE_LENGTH; + } + } + + CATLASS_DEVICE + ComputeLayout GetPaddingComputeLayout(layout::RowMajor const &layout) + { + return ComputeLayout(layout.shape(0), layout.shape(1), layout.stride(0)); + } + + CATLASS_DEVICE + ComputeLayout GetPaddingComputeLayout(layout::ColumnMajor const &layout) + { + return ComputeLayout(layout.shape(1), layout.shape(0), layout.stride(1)); + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::GlobalTensor const &src, + Layout layoutDst, Layout layoutSrc) + { + ComputeLayout computeLayoutSrc = GetPaddingComputeLayout(layoutSrc); + ComputeLayout computeLayoutDst = GetPaddingComputeLayout(layoutDst); + + uint32_t aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum(); + uint32_t aivId = AscendC::GetBlockIdx(); + + // Each line is a tile. + uint32_t tilesNum = computeLayoutSrc.shape(0); + uint32_t tileLen = computeLayoutSrc.shape(1); + uint32_t paddingStride = computeLayoutDst.stride(0); + + uint32_t tilesPerAiv = tilesNum / aivNum; + uint32_t tileRemain = tilesNum % aivNum; + if (aivId < tileRemain) { + tilesPerAiv++; + } + uint32_t mIdx = aivId * tilesPerAiv; + if (aivId >= tileRemain) { + mIdx += tileRemain; + } + MatrixCoord blockOffset(mIdx, 0); + + AscendC::SetFlag(eventIds[0]); + AscendC::SetFlag(eventIds[1]); + uint32_t coreLoops{0}; + if (paddingStride > COMPUTE_LENGTH) { + // Handle the same tile on multiple loops. 
+ uint32_t loopsPerTile = CeilDiv(tileLen, COMPUTE_LENGTH); + coreLoops = tilesPerAiv * loopsPerTile; + for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) { + uint32_t tileIdx = loopIdx / loopsPerTile; + uint32_t inTileLoopIdx = loopIdx % loopsPerTile; + MatrixCoord loopOffset(tileIdx, inTileLoopIdx * COMPUTE_LENGTH); + uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + loopOffset); + uint32_t actualDataNum = COMPUTE_LENGTH; + if (tileLen - inTileLoopIdx * COMPUTE_LENGTH < COMPUTE_LENGTH) { + actualDataNum = tileLen - inTileLoopIdx * COMPUTE_LENGTH; + } + AscendC::WaitFlag(eventIds[bufferIndex]); + ComputeLayout dstLayout = computeLayoutDst.GetTileLayout(MatrixCoord(1, actualDataNum)); + ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(1, actualDataNum)); + ComputeLayout &ubLayout = dstLayout; + copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + AscendC::WaitFlag(eventIds[bufferIndex]); + uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + loopOffset); + copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex], dstLayout, ubLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + bufferIndex = 1 - bufferIndex; + } + } else { + // Handle multiple tile each loop. + uint32_t tilesPerLoop = COMPUTE_LENGTH / paddingStride; + coreLoops = CeilDiv(tilesPerAiv, tilesPerLoop); + for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) { + uint32_t tileIdx = loopIdx * tilesPerLoop; + MatrixCoord tileOffset(tileIdx, 0); + uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + tileOffset); + uint32_t actualTilesNum = tilesPerLoop; + if (tilesPerAiv - tileIdx < tilesPerLoop) { + actualTilesNum = tilesPerAiv - tileIdx; + } + AscendC::WaitFlag(eventIds[bufferIndex]); + ComputeLayout dstLayout = computeLayoutDst.GetTileLayout(MatrixCoord(actualTilesNum, tileLen)); + ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(actualTilesNum, tileLen)); + ComputeLayout &ubLayout = dstLayout; + copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + AscendC::WaitFlag(eventIds[bufferIndex]); + uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + tileOffset); + copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex], dstLayout, ubLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + bufferIndex = 1 - bufferIndex; + } + } + AscendC::WaitFlag(eventIds[0]); + AscendC::WaitFlag(eventIds[1]); + } + + CATLASS_DEVICE + ~PaddingMatrixND() {} + +private: + static const uint32_t BUFFER_NUM = 2; + AscendC::LocalTensor inputBuffer[BUFFER_NUM]; + AscendC::TEventID eventIds[BUFFER_NUM] = {EVENT_ID0, EVENT_ID1}; + uint32_t bufferIndex{0}; + static_assert(BUFFER_NUM * COMPUTE_LENGTH * sizeof(Element) <= ArchTag::UB_SIZE, "Exceeding the UB space!"); +}; + +template +class KernelGemm +{ +public: + using BlockGemm = BlockGemm_; + using ArchTag = typename BlockGemm::ArchTag; + using L1TileShape = typename BlockGemm::L1TileShape; + using ElementA = typename BlockGemm::ElementA; + using LayoutA = typename BlockGemm::LayoutA; + using LayoutWA = typename BlockGemm::LayoutA; + using ElementB = typename BlockGemm::ElementB; + using LayoutB = typename BlockGemm::LayoutB; + using LayoutWB = typename BlockGemm::LayoutB; + using ElementC = typename BlockGemm::ElementC; + using LayoutC = typename BlockGemm::LayoutC; + using ElementAccumulator = typename BlockGemm::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using 
EpilogueParams = typename BlockEpilogue::Params; + + const uint32_t maxMPerBlock = L1TileShape::M; + const uint32_t maxNPerBlock = L1TileShape::N; + const uint32_t cSize = maxMPerBlock * maxNPerBlock * sizeof(ElementAccumulator); + const uint32_t l0CBlockNum = ArchTag::L0C_SIZE / cSize; + using ElementCompute = + typename Catlass::Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using ElementScalar = ElementCompute; + static constexpr uint32_t STAGES = BlockGemm::STAGES; + using BlockScheduler = BlockScheduler_; + + static const uint32_t COMPUTE_LENGTH_A = 96 * 1024 / sizeof(ElementA); + using PaddingA = PaddingMatrixND; + static const uint32_t COMPUTE_LENGTH_B = 96 * 1024 / sizeof(ElementB); + using PaddingB = PaddingMatrixND; + + struct Params { + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR gmWorkspace; + GM_ADDR ptrWA; + LayoutA layoutWA; + GM_ADDR ptrWB; + LayoutB layoutWB; + EpilogueParams epilogueParams; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR gmWorkspace_, GM_ADDR ptrWA_, LayoutA layoutWA_, GM_ADDR ptrWB_, LayoutB layoutWB_, + EpilogueParams epilogueParams_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + gmWorkspace(gmWorkspace_), + ptrWA(ptrWA_), + layoutWA(layoutWA_), + ptrWB(ptrWB_), + layoutWB(layoutWB_), + epilogueParams(epilogueParams_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + uint32_t align; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR gmWorkspace; + GM_ADDR ptrWA; + GM_ADDR ptrWB; + EpilogueParams epilogueParams; + }; + + static layout::RowMajor GetWorkspaceLayout(layout::RowMajor layout, uint32_t align) + { + if (align == 0) { + return layout; + } + return layout::RowMajor(layout.shape(0), layout.shape(1), RoundUp(layout.shape(1), align)); + } + + static layout::ColumnMajor GetWorkspaceLayout(layout::ColumnMajor layout, uint32_t align) + { + if (align == 0) { + return layout; + } + return layout::ColumnMajor(layout.shape(0), layout.shape(1), RoundUp(layout.shape(0), align)); + } + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + LayoutA layoutA{args.problemShape.m(), args.problemShape.k()}; + LayoutB layoutB{args.problemShape.k(), args.problemShape.n()}; + LayoutWA layoutWA = GetWorkspaceLayout(layoutA, args.align); + LayoutWB layoutWB = GetWorkspaceLayout(layoutB, args.align); + Params params{args.problemShape, args.ptrA, layoutA, args.ptrB, layoutB, + args.gmWorkspace, args.ptrWA, layoutWA, args.ptrWB, layoutWB, + args.epilogueParams}; + return params; + } + + CATLASS_DEVICE + bool IsSameStride(layout::RowMajor layout1, layout::RowMajor layout2) + { + return layout1.stride(0) == layout2.stride(0); + } + CATLASS_DEVICE + bool IsSameStride(layout::ColumnMajor layout1, layout::ColumnMajor layout2) + { + return layout1.stride(1) == layout2.stride(1); + } + + CATLASS_DEVICE + KernelGemm() {} + + CATLASS_DEVICE + ~KernelGemm() {} + + template + CATLASS_DEVICE void operator()(Params ¶ms) + {} + + template <> + CATLASS_DEVICE void operator()(Params ¶ms) + { + if (!IsSameStride(params.layoutWA, params.layoutA) || !IsSameStride(params.layoutWB, params.layoutB)) { + 
Arch::CrossCoreWaitFlag(flagAivFinishPadding); + } + Arch::Resource resource; + BlockGemm blockGemm(resource); + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrWA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrWB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.gmWorkspace); + uint32_t M = params.problemShape.m(); + uint32_t N = params.problemShape.n(); + uint32_t K = params.problemShape.k(); +#pragma unroll + for (uint32_t i = 0; i < l0CBlockNum; i++) { + AscendC::SetFlag((int32_t)i); + } + uint32_t mLoops = CeilDiv(M, maxMPerBlock); + uint32_t nLoops = CeilDiv(N, maxNPerBlock); + uint32_t coreLoops = mLoops * nLoops; + uint32_t singleIdx = 0; + LayoutC layoutC(params.problemShape.m(), params.problemShape.n()); + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + uint32_t mGmBlockIdx = loopIdx / nLoops; + uint32_t nGmBlockIdx = loopIdx % nLoops; + uint32_t mGmActual = (mGmBlockIdx == mLoops - 1) ? (M - mGmBlockIdx * maxMPerBlock) : maxMPerBlock; + uint32_t nGmActual = (nGmBlockIdx == nLoops - 1) ? (N - nGmBlockIdx * maxNPerBlock) : maxNPerBlock; + bool isFirstBlock = (loopIdx == AscendC::GetBlockIdx()); + bool hasNextBlock = false; + GemmCoord nextActualShape; + uint32_t mNextGmBlockIdx = 0; + uint32_t nNextGmBlockIdx = 0; + if (loopIdx + AscendC::GetBlockNum() < coreLoops) { + hasNextBlock = true; + uint32_t nextLoopIdx = loopIdx + AscendC::GetBlockNum(); + mNextGmBlockIdx = nextLoopIdx / nLoops; + nNextGmBlockIdx = nextLoopIdx % nLoops; + uint32_t mNextGmActual = + (mNextGmBlockIdx == mLoops - 1) ? (M - mNextGmBlockIdx * maxMPerBlock) : maxMPerBlock; + uint32_t nNextGmActual = + (nNextGmBlockIdx == nLoops - 1) ? 
(N - nNextGmBlockIdx * maxNPerBlock) : maxNPerBlock; + nextActualShape = MakeCoord(mNextGmActual, nNextGmActual, K); + } + GemmCoord actualShape{mGmActual, nGmActual, K}; + AscendC::WaitFlag((int32_t)singleIdx); + MatrixCoord gmTileAOffset{mGmBlockIdx * maxMPerBlock, 0}; + auto gmTileA = gmA[params.layoutWA.GetOffset(gmTileAOffset)]; + MatrixCoord gmTileBOffset{0, nGmBlockIdx * maxNPerBlock}; + auto gmTileB = gmB[params.layoutWB.GetOffset(gmTileBOffset)]; + MatrixCoord gmTileCOffset{mGmBlockIdx * maxMPerBlock, nGmBlockIdx * maxNPerBlock}; + auto gmTileC = gmC[layoutC.GetOffset(gmTileCOffset)]; + MatrixCoord gmTileNextAOffset{mNextGmBlockIdx * maxMPerBlock, 0}; + auto gmTileNextA = gmA[params.layoutWA.GetOffset(gmTileNextAOffset)]; + MatrixCoord gmTileNextBOffset{0, nNextGmBlockIdx * maxNPerBlock}; + auto gmTileNextB = gmB[params.layoutWB.GetOffset(gmTileNextBOffset)]; + blockGemm(gmTileA, params.layoutWA, gmTileB, params.layoutWB, gmTileC, layoutC, gmTileNextA, gmTileNextB, + actualShape, nextActualShape, isFirstBlock, hasNextBlock, singleIdx); + Arch::CrossCoreSetFlagWithReverse<0x2, PIPE_FIX>(flagAicFinishStore); + AscendC::SetFlag((int32_t)singleIdx); + singleIdx = (singleIdx + 1) % l0CBlockNum; + } +#pragma unroll + for (uint32_t i = 0; i < l0CBlockNum; i++) { + AscendC::WaitFlag((int32_t)i); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params ¶ms) + { + Arch::Resource resource; + uint64_t inGroupOffsetWorkspace = 0; + if (!IsSameStride(params.layoutWA, params.layoutA)) { + AscendC::GlobalTensor gmA; + AscendC::GlobalTensor gmWA; + gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrA)); + gmWA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrWA)); + PaddingA paddingA(resource); + paddingA(gmWA, gmA, params.layoutWA, params.layoutA); + } + + if (!IsSameStride(params.layoutWB, params.layoutB)) { + AscendC::GlobalTensor gmB; + AscendC::GlobalTensor gmWB; + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(params.ptrB)); + gmWB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(params.ptrWB)); + PaddingB paddingB(resource); + paddingB(gmWB, gmB, params.layoutWB, params.layoutB); + // 0x0 synchronization control between AI Core + } + if (!IsSameStride(params.layoutWA, params.layoutA) || !IsSameStride(params.layoutWB, params.layoutB)) { + Catlass::Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishPadding); + } + GemmCoord blockShape = L1TileShape::ToCoord(); + BlockEpilogue blockEpilogue(resource, blockShape, params.epilogueParams); + uint32_t M = params.problemShape.m(); + uint32_t N = params.problemShape.n(); + uint32_t K = params.problemShape.k(); + uint32_t mLoops = CeilDiv(M, maxMPerBlock); + uint32_t nLoops = CeilDiv(N, maxNPerBlock); + uint32_t coreLoops = mLoops * nLoops; + uint32_t aivNum = AscendC::GetSubBlockNum(); + uint32_t aivIndex = AscendC::GetBlockIdx(); + uint32_t aicoreIndex = aivIndex / aivNum; + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.gmWorkspace); + for (uint32_t loopIdx = aicoreIndex; loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + uint32_t mGmBlockIdx = loopIdx / nLoops; + uint32_t nGmBlockIdx = loopIdx % nLoops; + uint32_t mGmActual = (mGmBlockIdx == mLoops - 1) ? (M - mGmBlockIdx * maxMPerBlock) : maxMPerBlock; + uint32_t nGmActual = (nGmBlockIdx == nLoops - 1) ? 
(N - nGmBlockIdx * maxNPerBlock) : maxNPerBlock; + GemmCoord actualShape{mGmActual, nGmActual, K}; + GemmCoord blockCoord{mGmBlockIdx, nGmBlockIdx, 0}; + LayoutC layoutC(params.problemShape.m(), params.problemShape.n()); + Arch::CrossCoreWaitFlagWithReverse<0x2, PIPE_MTE3>(flagAicFinishStore); + blockEpilogue(actualShape, blockCoord, gmC, layoutC, inGroupOffsetWorkspace); + } + inGroupOffsetWorkspace += params.problemShape.m() * params.problemShape.n(); + + AscendC::PipeBarrier(); + } + +private: + static constexpr Arch::FlagID FLAG_AIC_FINISH_STORE = 0; + static constexpr Arch::FlagID RV_FLAG_AIC_FINISH_STORE = 1; + Arch::CrossCoreFlagWithReverse<> flagAicFinishStore{FLAG_AIC_FINISH_STORE, RV_FLAG_AIC_FINISH_STORE}; + static constexpr Arch::FlagID FLAG_AIV_FINISH_STORE = 0; + Arch::CrossCoreFlag flagAivFinishPadding{FLAG_AIV_FINISH_STORE}; +}; +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_GEMM_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/group_gemm.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/group_gemm.hpp new file mode 100644 index 00000000..5020792d --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/group_gemm.hpp @@ -0,0 +1,516 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_KERNEL_GROUPGEMM_HPP +#define CATLASS_GEMM_KERNEL_GROUPGEMM_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/epilogue/tile/copy_gm_to_ub.hpp" +#include "catlass/epilogue/tile/copy_ub_to_gm.hpp" +#include "catlass/gemm/helper.hpp" + +namespace Catlass::Gemm::Kernel { + +namespace detail { + +template +CATLASS_DEVICE void UnpackListParam(T *const dst, GM_ADDR src, uint32_t len) +{ + for (uint32_t i = 0; i * sizeof(uint64_t) < len * sizeof(T); ++i) { + reinterpret_cast(dst)[i] = reinterpret_cast<__gm__ uint64_t *>(src)[i]; + } +} + +} // namespace detail + +template +struct PaddingMatrixND { +public: + using ArchTag = ArchTag_; + using Element = Element_; + using Layout = Layout_; + using CopyGm2Ub = Catlass::Epilogue::Tile::CopyGm2Ub>; + using CopyUb2Gm = Catlass::Epilogue::Tile::CopyUb2Gm>; + using ComputeLayout = Catlass::layout::RowMajor; + + CopyGm2Ub copyGm2Ub; + CopyUb2Gm copyUb2Gm; + + CATLASS_DEVICE + PaddingMatrixND(Arch::Resource &resource) + { + int64_t bufferOffset = 0; + for (uint32_t i = 0; i < BUFFER_NUM; i++) { // + inputBuffer[i] = resource.ubBuf.template GetBufferByByte(bufferOffset * sizeof(Element)); + bufferOffset += COMPUTE_LENGTH; + } + } + + CATLASS_DEVICE + ComputeLayout GetPaddingComputeLayout(layout::RowMajor const &layout) + { + return ComputeLayout(layout.shape(0), layout.shape(1), layout.stride(0)); + } + + CATLASS_DEVICE + ComputeLayout GetPaddingComputeLayout(layout::ColumnMajor const &layout) + { + return ComputeLayout(layout.shape(1), layout.shape(0), layout.stride(1)); + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::GlobalTensor const &src, + Layout layoutDst, Layout layoutSrc) + { + ComputeLayout computeLayoutSrc = GetPaddingComputeLayout(layoutSrc); + ComputeLayout computeLayoutDst = GetPaddingComputeLayout(layoutDst); + + uint32_t aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum(); + uint32_t aivId = AscendC::GetBlockIdx(); + + // Each line is a tile. + uint32_t tilesNum = computeLayoutSrc.shape(0); + uint32_t tileLen = computeLayoutSrc.shape(1); + uint32_t paddingStride = computeLayoutDst.stride(0); + + uint32_t tilesPerAiv = tilesNum / aivNum; + uint32_t tileRemain = tilesNum % aivNum; + if (aivId < tileRemain) { + tilesPerAiv++; + } + uint32_t mIdx = aivId * tilesPerAiv; + if (aivId >= tileRemain) { + mIdx += tileRemain; + } + MatrixCoord blockOffset(mIdx, 0); + + AscendC::SetFlag(eventIds[0]); + AscendC::SetFlag(eventIds[1]); + uint32_t coreLoops{0}; + if (paddingStride > COMPUTE_LENGTH) { + // Handle the same tile on multiple loops. 
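// Illustrative note: this branch covers the case where one padded row (paddingStride elements)
// no longer fits in a single UB staging buffer of COMPUTE_LENGTH elements, so each row is
// streamed in CeilDiv(tileLen, COMPUTE_LENGTH) chunks; the else branch below instead packs
// COMPUTE_LENGTH / paddingStride whole rows into every round trip. As a worked example with
// assumed values (ElementA = half, so COMPUTE_LENGTH_A = 96 * 1024 / 2 = 49152 elements), a row
// of tileLen = 131072 is copied in CeilDiv(131072, 49152) = 3 loops, the last one moving
// 131072 - 2 * 49152 = 32768 elements. In both branches the copy ping-pongs between
// inputBuffer[0] and inputBuffer[1]: the two event IDs serialize each buffer's GM -> UB load
// against its previous UB -> GM store, so one buffer can be filled while the other drains.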
+ uint32_t loopsPerTile = CeilDiv(tileLen, COMPUTE_LENGTH); + coreLoops = tilesPerAiv * loopsPerTile; + for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) { + uint32_t tileIdx = loopIdx / loopsPerTile; + uint32_t inTileLoopIdx = loopIdx % loopsPerTile; + MatrixCoord loopOffset(tileIdx, inTileLoopIdx * COMPUTE_LENGTH); + uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + loopOffset); + uint32_t actualDataNum = COMPUTE_LENGTH; + if (tileLen - inTileLoopIdx * COMPUTE_LENGTH < COMPUTE_LENGTH) { + actualDataNum = tileLen - inTileLoopIdx * COMPUTE_LENGTH; + } + AscendC::WaitFlag(eventIds[bufferIndex]); + ComputeLayout dstLayout = computeLayoutDst.GetTileLayout(MatrixCoord(1, actualDataNum)); + ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(1, actualDataNum)); + ComputeLayout &ubLayout = dstLayout; + copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + AscendC::WaitFlag(eventIds[bufferIndex]); + uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + loopOffset); + copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex], dstLayout, ubLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + bufferIndex = 1 - bufferIndex; + } + } else { + // Handle multiple tile each loop. + uint32_t tilesPerLoop = COMPUTE_LENGTH / paddingStride; + coreLoops = CeilDiv(tilesPerAiv, tilesPerLoop); + for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) { + uint32_t tileIdx = loopIdx * tilesPerLoop; + MatrixCoord tileOffset(tileIdx, 0); + uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + tileOffset); + uint32_t actualTilesNum = tilesPerLoop; + if (tilesPerAiv - tileIdx < tilesPerLoop) { + actualTilesNum = tilesPerAiv - tileIdx; + } + AscendC::WaitFlag(eventIds[bufferIndex]); + ComputeLayout dstLayout = computeLayoutDst.GetTileLayout(MatrixCoord(actualTilesNum, tileLen)); + ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(actualTilesNum, tileLen)); + ComputeLayout &ubLayout = dstLayout; + copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + AscendC::WaitFlag(eventIds[bufferIndex]); + uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + tileOffset); + copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex], dstLayout, ubLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + bufferIndex = 1 - bufferIndex; + } + } + AscendC::WaitFlag(eventIds[0]); + AscendC::WaitFlag(eventIds[1]); + } + + CATLASS_DEVICE + ~PaddingMatrixND() {} + +private: + static const uint32_t BUFFER_NUM = 2; + AscendC::LocalTensor inputBuffer[BUFFER_NUM]; + AscendC::TEventID eventIds[BUFFER_NUM] = {EVENT_ID0, EVENT_ID1}; + uint32_t bufferIndex{0}; + static_assert(BUFFER_NUM * COMPUTE_LENGTH * sizeof(Element) <= ArchTag::UB_SIZE, "Exceeding the UB space!"); +}; + +template +class KernelGroupGemm +{ +public: + using BlockGemm = BlockGemm_; + using ArchTag = typename BlockGemm::ArchTag; + using L1TileShape = typename BlockGemm::L1TileShape; + using ElementA = typename BlockGemm::ElementA; + using LayoutA = typename BlockGemm::LayoutA; + using LayoutWA = typename BlockGemm::LayoutA; + using ElementB = typename BlockGemm::ElementB; + using LayoutB = typename BlockGemm::LayoutB; + using LayoutWB = typename BlockGemm::LayoutB; + using ElementC = typename BlockGemm::ElementC; + using LayoutC = typename BlockGemm::LayoutC; + using ElementAccumulator = typename BlockGemm::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + 
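// Illustrative note on the overall flow (the epilogue formula is inferred from the parameter
// names, not stated in this file): the AIV specialization of operator() re-lays each group's
// A and B into the WA/WB workspaces via PaddingA/PaddingB and signals flagAivFinishPadding;
// the AIC specialization then accumulates that group's A * B tiles into ptrWorkspace and
// raises flagAicFinishStore per stored tile; finally the AIV side runs BlockEpilogue over the
// stored tiles, presumably computing D = alpha * (A * B) + beta * X per group given the
// alpha/beta/ptrX/ptrD fields carried by EpilogueParams.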
using EpilogueParams = typename BlockEpilogue::Params; + using ElementCompute = + typename Catlass::Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using ElementScalar = ElementCompute; + static constexpr uint32_t MAX_TENSOR_COUNT = 32; + + const uint32_t maxMPerBlock = L1TileShape::M; + const uint32_t maxNPerBlock = L1TileShape::N; + const uint32_t cSize = maxMPerBlock * maxNPerBlock * sizeof(ElementAccumulator); + const uint32_t l0CBlockNum = ArchTag::L0C_SIZE / cSize; + + static constexpr uint32_t STAGES = BlockGemm::STAGES; + using BlockScheduler = BlockScheduler_; + + static const uint32_t COMPUTE_LENGTH_A = 96 * 1024 / sizeof(ElementA); + using PaddingA = PaddingMatrixND; + static const uint32_t COMPUTE_LENGTH_B = 96 * 1024 / sizeof(ElementB); + using PaddingB = PaddingMatrixND; + + struct Params { + // Data members + uint32_t problemCount; + GM_ADDR ptrProblemShape; + GM_ADDR alpha; + GM_ADDR beta; + GM_ADDR ptrA; + GM_ADDR ptrLayoutA; + GM_ADDR ptrB; + GM_ADDR ptrLayoutB; + GM_ADDR ptrWorkspace; + GM_ADDR ptrLayoutWorkspace; + GM_ADDR ptrWA; + GM_ADDR ptrLayoutWA; + GM_ADDR ptrWB; + GM_ADDR ptrLayoutWB; + GM_ADDR ptrX; + GM_ADDR ptrD; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(uint32_t problemCount_, GM_ADDR ptrProblemShape_, GM_ADDR alpha_, GM_ADDR beta_, GM_ADDR ptrA_, + GM_ADDR ptrLayoutA_, GM_ADDR ptrB_, GM_ADDR ptrLayoutB_, GM_ADDR ptrWorkspace_, + GM_ADDR ptrLayoutWorkspace_, GM_ADDR ptrWA_, GM_ADDR ptrLayoutWA_, GM_ADDR ptrWB_, GM_ADDR ptrLayoutWB_, + GM_ADDR ptrX_, GM_ADDR ptrD_) + : problemCount(problemCount_), + ptrProblemShape(ptrProblemShape_), + alpha(alpha_), + beta(beta_), + ptrA(ptrA_), + ptrLayoutA(ptrLayoutA_), + ptrB(ptrB_), + ptrLayoutB(ptrLayoutB_), + ptrWorkspace(ptrWorkspace_), + ptrLayoutWorkspace(ptrLayoutWorkspace_), + ptrWA(ptrWA_), + ptrLayoutWA(ptrLayoutWA_), + ptrWB(ptrWB_), + ptrLayoutWB(ptrLayoutWB_), + ptrX(ptrX_), + ptrD(ptrD_) + {} + }; + + struct Arguments { + uint32_t problemCount; + GM_ADDR ptrProblemShape; + GM_ADDR alpha; + GM_ADDR beta; + GM_ADDR ptrA; + GM_ADDR ptrLayoutA; + GM_ADDR ptrB; + GM_ADDR ptrLayoutB; + GM_ADDR ptrWorkspace; + GM_ADDR ptrLayoutWorkspace; + GM_ADDR ptrWA; + GM_ADDR ptrLayoutWA; + GM_ADDR ptrWB; + GM_ADDR ptrLayoutWB; + GM_ADDR ptrX; + GM_ADDR ptrD; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + Params params{args.problemCount, args.ptrProblemShape, args.alpha, args.beta, + args.ptrA, args.ptrLayoutA, args.ptrB, args.ptrLayoutB, + args.ptrWorkspace, args.ptrLayoutWorkspace, args.ptrWA, args.ptrLayoutWA, + args.ptrWB, args.ptrLayoutWB, args.ptrX, args.ptrD}; + return params; + } + + CATLASS_DEVICE + KernelGroupGemm() {} + + CATLASS_DEVICE + ~KernelGroupGemm() {} + + CATLASS_DEVICE + size_t GetWorkspaceLen(layout::RowMajor layout) + { + return layout.shape(0) * layout.stride(0); + } + + CATLASS_DEVICE + size_t GetWorkspaceLen(layout::ColumnMajor layout) + { + return layout.shape(1) * layout.stride(1); + } + + template + CATLASS_DEVICE void operator()(Params ¶ms) + {} + + template <> + CATLASS_DEVICE void operator()(Params ¶ms) + { + GemmCoord problemShapeList[MAX_TENSOR_COUNT]; + LayoutA layoutAList[MAX_TENSOR_COUNT]; + LayoutB layoutBList[MAX_TENSOR_COUNT]; + LayoutC layoutWorkspaceList[MAX_TENSOR_COUNT]; + LayoutA layoutWAList[MAX_TENSOR_COUNT]; + LayoutB 
layoutWBList[MAX_TENSOR_COUNT]; + // Get matmul information from parameters + detail::UnpackListParam(problemShapeList, params.ptrProblemShape, params.problemCount); + detail::UnpackListParam(layoutAList, params.ptrLayoutA, params.problemCount); + detail::UnpackListParam(layoutBList, params.ptrLayoutB, params.problemCount); + detail::UnpackListParam(layoutWorkspaceList, params.ptrLayoutWorkspace, params.problemCount); + detail::UnpackListParam(layoutWAList, params.ptrLayoutWA, params.problemCount); + detail::UnpackListParam(layoutWBList, params.ptrLayoutWB, params.problemCount); + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + uint64_t inGroupOffsetA = 0; + uint64_t inGroupOffsetB = 0; + uint64_t inGroupOffsetWorkspace = 0; + uint32_t startCoreIdx = 0; + uint32_t startLoopIdx; + Arch::Resource resource; + BlockGemm blockGemm(resource); + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + GemmCoord problemShape = problemShapeList[groupIdx]; + LayoutA layoutA = layoutAList[groupIdx]; + LayoutB layoutB = layoutBList[groupIdx]; + LayoutC layoutWorkspace = layoutWorkspaceList[groupIdx]; + LayoutA layoutWA = layoutWAList[groupIdx]; + LayoutB layoutWB = layoutWBList[groupIdx]; + Arch::CrossCoreWaitFlag(flagAivFinishPadding); + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrWA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrWB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrWorkspace); + uint32_t M = problemShape.m(); + uint32_t N = problemShape.n(); + uint32_t K = problemShape.k(); +#pragma unroll + for (uint32_t i = 0; i < l0CBlockNum; i++) { + AscendC::SetFlag((int32_t)i); + } + uint32_t mLoops = CeilDiv(M, maxMPerBlock); + uint32_t nLoops = CeilDiv(N, maxNPerBlock); + uint32_t coreLoops = mLoops * nLoops; + // Determine the starting loopIdx of the current core under the current groupIdx + if (coreIdx < startCoreIdx) { + startLoopIdx = coreIdx + coreNum - startCoreIdx; + } else { + startLoopIdx = coreIdx - startCoreIdx; + } + uint32_t singleIdx = 0; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + uint32_t mGmBlockIdx = loopIdx / nLoops; + uint32_t nGmBlockIdx = loopIdx % nLoops; + uint32_t mGmActual = (mGmBlockIdx == mLoops - 1) ? (M - mGmBlockIdx * maxMPerBlock) : maxMPerBlock; + uint32_t nGmActual = (nGmBlockIdx == nLoops - 1) ? (N - nGmBlockIdx * maxNPerBlock) : maxNPerBlock; + bool isFirstBlock = (loopIdx == startLoopIdx); + bool hasNextBlock = false; + GemmCoord nextActualShape; + uint32_t mNextGmBlockIdx = 0; + uint32_t nNextGmBlockIdx = 0; + if (loopIdx + AscendC::GetBlockNum() < coreLoops) { + hasNextBlock = true; + uint32_t nextLoopIdx = loopIdx + AscendC::GetBlockNum(); + mNextGmBlockIdx = nextLoopIdx / nLoops; + nNextGmBlockIdx = nextLoopIdx % nLoops; + uint32_t mNextGmActual = + (mNextGmBlockIdx == mLoops - 1) ? (M - mNextGmBlockIdx * maxMPerBlock) : maxMPerBlock; + uint32_t nNextGmActual = + (nNextGmBlockIdx == nLoops - 1) ? 
(N - nNextGmBlockIdx * maxNPerBlock) : maxNPerBlock; + nextActualShape = MakeCoord(mNextGmActual, nNextGmActual, K); + } + GemmCoord actualShape{mGmActual, nGmActual, K}; + AscendC::WaitFlag((int32_t)singleIdx); + MatrixCoord gmTileAOffset{mGmBlockIdx * maxMPerBlock, 0}; + auto gmTileA = gmA[inGroupOffsetA + layoutWA.GetOffset(gmTileAOffset)]; + MatrixCoord gmTileBOffset{0, nGmBlockIdx * maxNPerBlock}; + auto gmTileB = gmB[inGroupOffsetB + layoutWB.GetOffset(gmTileBOffset)]; + MatrixCoord gmTileCOffset{mGmBlockIdx * maxMPerBlock, nGmBlockIdx * maxNPerBlock}; + auto gmTileC = gmC[inGroupOffsetWorkspace + layoutWorkspace.GetOffset(gmTileCOffset)]; + MatrixCoord gmTileNextAOffset{mNextGmBlockIdx * maxMPerBlock, 0}; + auto gmTileNextA = gmA[inGroupOffsetA + layoutWA.GetOffset(gmTileNextAOffset)]; + MatrixCoord gmTileNextBOffset{0, nNextGmBlockIdx * maxNPerBlock}; + auto gmTileNextB = gmB[inGroupOffsetB + layoutWB.GetOffset(gmTileNextBOffset)]; + blockGemm(gmTileA, layoutWA, gmTileB, layoutWB, gmTileC, layoutWorkspace, gmTileNextA, gmTileNextB, + actualShape, nextActualShape, isFirstBlock, hasNextBlock, singleIdx); + Arch::CrossCoreSetFlagWithReverse<0x2, PIPE_FIX>(flagAicFinishStore); + AscendC::SetFlag((int32_t)singleIdx); + singleIdx = (singleIdx + 1) % l0CBlockNum; + } + inGroupOffsetA += GetWorkspaceLen(layoutWA); + inGroupOffsetB += GetWorkspaceLen(layoutWB); + inGroupOffsetWorkspace += problemShape.m() * problemShape.n(); + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; +#pragma unroll + for (uint32_t i = 0; i < l0CBlockNum; i++) { + AscendC::WaitFlag((int32_t)i); + } + } + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params ¶ms) + { + GemmCoord problemShapeList[MAX_TENSOR_COUNT]; + LayoutA layoutAList[MAX_TENSOR_COUNT]; + LayoutB layoutBList[MAX_TENSOR_COUNT]; + LayoutC layoutWorkspaceList[MAX_TENSOR_COUNT]; + ElementScalar alphaList[MAX_TENSOR_COUNT]; + ElementScalar betaList[MAX_TENSOR_COUNT]; + LayoutA layoutWAList[MAX_TENSOR_COUNT]; + LayoutB layoutWBList[MAX_TENSOR_COUNT]; + detail::UnpackListParam(problemShapeList, params.ptrProblemShape, params.problemCount); + detail::UnpackListParam(layoutAList, params.ptrLayoutA, params.problemCount); + detail::UnpackListParam(layoutBList, params.ptrLayoutB, params.problemCount); + detail::UnpackListParam(layoutWorkspaceList, params.ptrLayoutWorkspace, params.problemCount); + detail::UnpackListParam(alphaList, params.alpha, params.problemCount); + detail::UnpackListParam(betaList, params.beta, params.problemCount); + detail::UnpackListParam(layoutWAList, params.ptrLayoutWA, params.problemCount); + detail::UnpackListParam(layoutWBList, params.ptrLayoutWB, params.problemCount); + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + uint64_t inGroupOffsetA = 0; + uint64_t inGroupOffsetWA = 0; + uint64_t inGroupOffsetB = 0; + uint64_t inGroupOffsetWB = 0; + uint64_t inGroupOffsetWorkspace = 0; + uint32_t startCoreIdx = 0; + uint32_t startLoopIdx; + GemmCoord blockShape = L1TileShape::ToCoord(); + Arch::Resource resource; + AscendC::GlobalTensor gmA; + AscendC::GlobalTensor gmWA; + gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrA)); + gmWA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrWA)); + AscendC::GlobalTensor gmB; + AscendC::GlobalTensor gmWB; + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(params.ptrB)); + gmWB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(params.ptrWB)); + PaddingA 
paddingA(resource); + PaddingB paddingB(resource); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrWorkspace); + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + GemmCoord problemShape = problemShapeList[groupIdx]; + LayoutA layoutA = layoutAList[groupIdx]; + LayoutB layoutB = layoutBList[groupIdx]; + LayoutC layoutWorkspace = layoutWorkspaceList[groupIdx]; + ElementScalar alpha_ = alphaList[groupIdx]; + ElementScalar beta_ = betaList[groupIdx]; + LayoutA layoutWA = layoutWAList[groupIdx]; + LayoutB layoutWB = layoutWBList[groupIdx]; + paddingA(gmWA[inGroupOffsetWA], gmA[inGroupOffsetA], layoutWA, layoutA); + paddingB(gmWB[inGroupOffsetWB], gmB[inGroupOffsetB], layoutWB, layoutB); + Catlass::Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishPadding); + EpilogueParams epilogueParams{alpha_, beta_, params.ptrX, layoutWorkspace, params.ptrD, layoutWorkspace}; + BlockEpilogue blockEpilogue(resource, blockShape, epilogueParams); + uint32_t M = problemShape.m(); + uint32_t N = problemShape.n(); + uint32_t K = problemShape.k(); + uint32_t mLoops = CeilDiv(M, maxMPerBlock); + uint32_t nLoops = CeilDiv(N, maxNPerBlock); + uint32_t coreLoops = mLoops * nLoops; + if (coreIdx < startCoreIdx) { + startLoopIdx = coreIdx + coreNum - startCoreIdx; + } else { + startLoopIdx = coreIdx - startCoreIdx; + } + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + uint32_t mGmBlockIdx = loopIdx / nLoops; + uint32_t nGmBlockIdx = loopIdx % nLoops; + uint32_t mGmActual = (mGmBlockIdx == mLoops - 1) ? (M - mGmBlockIdx * maxMPerBlock) : maxMPerBlock; + uint32_t nGmActual = (nGmBlockIdx == nLoops - 1) ? (N - nGmBlockIdx * maxNPerBlock) : maxNPerBlock; + GemmCoord actualShape{mGmActual, nGmActual, K}; + GemmCoord blockCoord{mGmBlockIdx, nGmBlockIdx, 0}; + Arch::CrossCoreWaitFlagWithReverse<0x2, PIPE_MTE3>(flagAicFinishStore); + blockEpilogue(actualShape, blockCoord, gmC, layoutWorkspace, inGroupOffsetWorkspace); + } + inGroupOffsetA += problemShape.m() * problemShape.k(); + inGroupOffsetWA += GetWorkspaceLen(layoutWA); + inGroupOffsetB += problemShape.k() * problemShape.n(); + inGroupOffsetWB += GetWorkspaceLen(layoutWB); + inGroupOffsetWorkspace += problemShape.m() * problemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + AscendC::PipeBarrier(); + } + +private: + static constexpr Arch::FlagID FLAG_AIC_FINISH_STORE = 0; + static constexpr Arch::FlagID RV_FLAG_AIC_FINISH_STORE = 1; + Arch::CrossCoreFlagWithReverse<> flagAicFinishStore{FLAG_AIC_FINISH_STORE, RV_FLAG_AIC_FINISH_STORE}; + static constexpr Arch::FlagID FLAG_AIV_FINISH_STORE = 0; + Arch::CrossCoreFlag flagAivFinishPadding{FLAG_AIV_FINISH_STORE}; +}; +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_GROUPGEMM_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul.hpp new file mode 100644 index 00000000..b1e701c5 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul.hpp @@ -0,0 +1,207 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_GROUPED_MATMUL_HPP +#define CATLASS_GEMM_KERNEL_GROUPED_MATMUL_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +namespace detail { + +template +CATLASS_DEVICE void UnpackListParam(T *const dst, GM_ADDR src, uint32_t len) +{ + for (uint32_t i = 0; i * sizeof(uint64_t) < len * sizeof(T); ++i) { + reinterpret_cast(dst)[i] = reinterpret_cast<__gm__ uint64_t *>(src)[i]; + } +} + +} // namespace detail + +// Template for grouped matmul kernel. Compute grouped C = A * B +template +class GroupedMatmul +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t MAX_TENSOR_COUNT = 256; + + /// Parameters structure + struct Params { + // Data members + uint32_t problemCount; + GM_ADDR ptrProblemShape; + GM_ADDR ptrA; + GM_ADDR ptrLayoutA; + GM_ADDR ptrB; + GM_ADDR ptrLayoutB; + GM_ADDR ptrC; + GM_ADDR ptrLayoutC; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(uint32_t problemCount_, GM_ADDR ptrProblemShape_, GM_ADDR ptrA_, GM_ADDR ptrLayoutA_, GM_ADDR ptrB_, + GM_ADDR ptrLayoutB_, GM_ADDR ptrC_, GM_ADDR ptrLayoutC_) + : problemCount(problemCount_), + ptrProblemShape(ptrProblemShape_), + ptrA(ptrA_), + ptrLayoutA(ptrLayoutA_), + ptrB(ptrB_), + ptrLayoutB(ptrLayoutB_), + ptrC(ptrC_), + ptrLayoutC(ptrLayoutC_) + {} + }; + + struct Arguments { + uint32_t problemCount; + uint8_t *ptrProblemShape; + uint8_t *ptrA; + uint8_t *ptrLayoutA; + uint8_t *ptrB; + uint8_t *ptrLayoutB; + uint8_t *ptrC; + uint8_t *ptrLayoutC; + }; + static bool CanImplement(const Arguments &args) + { + return true; + } + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + static Params ToUnderlyingArguments(const Arguments &args, void *workspace) + { + Params params{args.problemCount, args.ptrProblemShape, args.ptrA, args.ptrLayoutA, + args.ptrB, args.ptrLayoutB, args.ptrC, args.ptrLayoutC}; + return params; + } + + // Methods + CATLASS_HOST_DEVICE + GroupedMatmul() {} + CATLASS_HOST_DEVICE + ~GroupedMatmul() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + /// Executes matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + GemmCoord problemShapeList[MAX_TENSOR_COUNT]; + LayoutA layoutAList[MAX_TENSOR_COUNT]; + LayoutB layoutBList[MAX_TENSOR_COUNT]; + LayoutC layoutCList[MAX_TENSOR_COUNT]; + + // Get matmul information from parameters + detail::UnpackListParam(problemShapeList, params.ptrProblemShape, params.problemCount); + detail::UnpackListParam(layoutAList, 
params.ptrLayoutA, params.problemCount); + detail::UnpackListParam(layoutBList, params.ptrLayoutB, params.problemCount); + detail::UnpackListParam(layoutCList, params.ptrLayoutC, params.problemCount); + + BlockScheduler matmulBlockScheduler; + Arch::Resource resource; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t inGroupOffsetA = 0; + int64_t inGroupOffsetB = 0; + int64_t inGroupOffsetC = 0; + + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + GemmCoord problemShape = problemShapeList[groupIdx]; + LayoutA layoutA = layoutAList[groupIdx]; + LayoutB layoutB = layoutBList[groupIdx]; + LayoutC layoutC = layoutCList[groupIdx]; + + matmulBlockScheduler.Update(problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx; + if (coreIdx < startCoreIdx) { + startLoopIdx = coreIdx + coreNum - startCoreIdx; + } else { + startLoopIdx = coreIdx - startCoreIdx; + } + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[inGroupOffsetA + gmOffsetA], layoutA, gmB[inGroupOffsetB + gmOffsetB], layoutB, + gmC[inGroupOffsetC + gmOffsetC], layoutC, actualBlockShape); + } + + inGroupOffsetA += problemShape.m() * problemShape.k(); + inGroupOffsetB += problemShape.k() * problemShape.n(); + inGroupOffsetC += problemShape.m() * problemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + {} +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_GROUPED_MATMUL_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_k.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_k.hpp new file mode 100644 index 00000000..8ba7d97c --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_k.hpp @@ -0,0 +1,278 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_GROUPED_MATMUL_K_HPP +#define CATLASS_GEMM_KERNEL_GROUPED_MATMUL_K_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +template +struct MemFill { +public: + using ArchTag = ArchTag_; + using Element = Element_; + + CATLASS_DEVICE + MemFill(Arch::Resource &resource) + { + ubBuffer = resource.ubBuf.template GetBufferByByte(0); + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dst, uint32_t elementCount, Element fillValue) + { + const uint32_t maxBurstSize = MAX_BURST_BYTES / sizeof(Element); + const uint32_t ubBufferSize = ubBuffer.GetSize() > maxBurstSize ? maxBurstSize : ubBuffer.GetSize(); + const uint32_t batchCount = elementCount / ubBufferSize; + const uint32_t tailElements = elementCount % ubBufferSize; + + // duplicate fillValue to ubBuffer for datacopy later + AscendC::Duplicate(ubBuffer, fillValue, ubBufferSize); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + uint32_t currentOffset = 0; + + // fill the main block by datacopy + if (batchCount > 0) { + for (int index = 0; index < batchCount; ++index) { + AscendC::DataCopyPad( + dst[currentOffset], ubBuffer, + AscendC::DataCopyExtParams(1, static_cast(ubBufferSize * sizeof(Element)), 0, 0, 0)); + currentOffset += ubBufferSize; + } + } + + // fill the tail block by datacopy + if (tailElements != 0) { + AscendC::DataCopyPad( + dst[currentOffset], ubBuffer, + AscendC::DataCopyExtParams(1, static_cast(tailElements * sizeof(Element)), 0, 0, 0)); + } + } + + CATLASS_DEVICE + ~MemFill() {} + +private: + static const size_t MAX_BURST_BYTES = 255 * 32; + AscendC::LocalTensor ubBuffer; +}; + +// Template for grouped matmul kernel. 
Compute grouped C = A * B +template +class GroupedMatmulSliceK +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + using ElementGroupList = ElementGroupList_; + using MemFill0 = MemFill; + + using BlockScheduler = BlockScheduler_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementC *ptrC; + LayoutC layoutC; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, + LayoutA const &layoutA_, GM_ADDR ptrB_, LayoutB const &layoutB_, GM_ADDR ptrC_, LayoutC const &layoutC_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrC(reinterpret_cast<__gm__ ElementC *>(ptrC_)), + layoutC(layoutC_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + uint32_t problemCount; + uint8_t *ptrGroupList; + uint8_t *ptrA; + uint8_t *ptrB; + uint8_t *ptrC; + }; + static bool CanImplement(const Arguments &args) + { + return true; + } + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + static Params ToUnderlyingArguments(const Arguments &args, void *workspace) + { + uint32_t m = args.problemShape.m(); + uint32_t n = args.problemShape.n(); + uint32_t k = args.problemShape.k(); + LayoutA layoutA{m, k}; + LayoutB layoutB{k, n}; + LayoutC layoutC{m, n}; + Params params{args.problemShape, args.problemCount, args.ptrGroupList, args.ptrA, layoutA, + args.ptrB, layoutB, args.ptrC, layoutC}; + return params; + } + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceK() {} + + CATLASS_DEVICE + ~GroupedMatmulSliceK() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + /// Executes matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t inGroupOffsetA = 0; + int64_t inGroupOffsetB = 0; + int64_t inGroupOffsetC = 0; + + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentK = (groupIdx == 0) ? 
groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord problemShape{params.problemShape.m(), params.problemShape.n(), currentK}; + + if (currentK == 0) { + inGroupOffsetA += problemShape.m() * problemShape.k(); + inGroupOffsetB += problemShape.k() * problemShape.n(); + inGroupOffsetC += problemShape.m() * problemShape.n(); + continue; + } + + LayoutA layoutA = params.layoutA.GetTileLayout(problemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB.GetTileLayout(problemShape.GetCoordKN()); + LayoutC layoutC = params.layoutC; + + blockScheduler.Update(problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx; + if (coreIdx < startCoreIdx) { + startLoopIdx = coreIdx + coreNum - startCoreIdx; + } else { + startLoopIdx = coreIdx - startCoreIdx; + } + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[inGroupOffsetA + gmOffsetA], layoutA, gmB[inGroupOffsetB + gmOffsetB], layoutB, + gmC[inGroupOffsetC + gmOffsetC], layoutC, actualBlockShape); + } + + inGroupOffsetA += problemShape.m() * problemShape.k(); + inGroupOffsetB += problemShape.k() * problemShape.n(); + inGroupOffsetC += problemShape.m() * problemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + MemFill0 memFill0(resource); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + int64_t inGroupOffsetC = 0; + + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentK = (groupIdx == 0) ? 
groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord problemShape{params.problemShape.m(), params.problemShape.n(), currentK}; + + if (currentK == 0) { + memFill0(gmC[inGroupOffsetC], problemShape.m() * problemShape.n(), 0); + } + inGroupOffsetC += problemShape.m() * problemShape.n(); + } + AscendC::PipeBarrier(); + } + +private: + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_GROUPED_MATMUL_K_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_k_per_token_dequant.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_k_per_token_dequant.hpp new file mode 100644 index 00000000..288c2977 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_k_per_token_dequant.hpp @@ -0,0 +1,328 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_GROUPED_MATMUL_K_PER_TOKEN_DEQUANT_HPP +#define CATLASS_GEMM_KERNEL_GROUPED_MATMUL_K_PER_TOKEN_DEQUANT_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +template +class GroupedMatmulSliceKPerTokenDequant +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using ElementGroupList = ElementGroupList_; + + using BlockScheduler = BlockScheduler_; + + friend class AicFinishSync; + friend class AivWaitSync; + + struct AicFinishSync { + using MatmulKernel = + GroupedMatmulSliceKPerTokenDequant; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlagWithReverse<0x2, PIPE_FIX>(ptr->flagAicFinishStore); + } + + MatmulKernel *ptr; + }; + + struct AivWaitSync { + using MatmulKernel = + GroupedMatmulSliceKPerTokenDequant; + + CATLASS_DEVICE + void 
operator()() const + { + Arch::CrossCoreWaitFlagWithReverse<0x2, PIPE_MTE3>(ptr->flagAicFinishStore); + } + + MatmulKernel *ptr; + }; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementD *ptrD; + LayoutD layoutD; + GM_ADDR ptrWorkspace; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_, + GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, LayoutD layoutD_, GM_ADDR ptrWorkspace_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)), + layoutD(layoutD_), + ptrWorkspace(ptrWorkspace_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + uint32_t problemCount; + uint8_t *ptrGroupList; + uint8_t *ptrA; + uint8_t *ptrB; + uint8_t *ptrScale; + uint8_t *ptrPerTokenScale; + uint8_t *ptrD; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + uint32_t m = args.problemShape.m(); + uint32_t n = args.problemShape.n(); + size_t lenD = static_cast(m) * n * args.problemCount; + size_t lenWorkspace = lenD; + size_t sizeWorkspace = lenWorkspace * sizeof(uint32_t); + return sizeWorkspace; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + uint32_t m = args.problemShape.m(); + uint32_t n = args.problemShape.n(); + uint32_t k = args.problemShape.k(); + LayoutA layoutA{m, k}; + LayoutB layoutB{k, n}; + LayoutScale layoutScale{n}; + LayoutPerTokenScale layoutPerTokenScale{m}; + LayoutD layoutD{m, n}; + Params params{args.problemShape, args.problemCount, args.ptrGroupList, args.ptrA, layoutA, + args.ptrB, layoutB, args.ptrScale, layoutScale, args.ptrPerTokenScale, + layoutPerTokenScale, args.ptrD, layoutD, workspace}; + return params; + } + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceKPerTokenDequant() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 
0; + int64_t gmGroupOffsetB = 0; + int64_t gmGroupOffsetC = 0; + + AicFinishSync aicFinishSync{this}; + + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentK = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{params.problemShape.m(), params.problemShape.n(), currentK}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB.GetTileLayout(inGroupProblemShape.GetCoordKN()); + LayoutC layoutC = LayoutC(inGroupProblemShape.m(), inGroupProblemShape.n()); + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmGroupOffsetC + gmOffsetC], layoutC, actualBlockShape, MakeCallback(&aicFinishSync)); + } else { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmGroupOffsetC + gmOffsetC], layoutC, actualBlockShape); + aicFinishSync(); + } + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource); + + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetC = 0; + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + AivWaitSync aicFinishSync{this}; + + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentK = (groupIdx == 0) ? 
groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{params.problemShape.m(), params.problemShape.n(), currentK}; + + LayoutC layoutC = LayoutC(inGroupProblemShape.m(), inGroupProblemShape.n()); + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN()); + + EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, + layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + params.ptrD + gmGroupOffsetD, + layoutD}; + + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + int64_t gmInGroupOffsetC = layoutC.GetOffset(blockCoordMNK.GetCoordMN() * blockShapeMNK.GetCoordMN()); + auto gmBlockC = gmC[gmGroupOffsetC + gmInGroupOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC, + MakeCallback(&aicFinishSync)); + } + + gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); + gmGroupOffsetScale += inGroupProblemShape.n(); + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + AscendC::PipeBarrier(); + } + +private: + static constexpr Arch::FlagID FLAG_AIC_FINISH_STORE = 0; + static constexpr Arch::FlagID RV_FLAG_AIC_FINISH_STORE = 1; + Arch::CrossCoreFlagWithReverse<> flagAicFinishStore{FLAG_AIC_FINISH_STORE, RV_FLAG_AIC_FINISH_STORE}; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_GROUPED_MATMUL_K_PER_TOKEN_DEQUANT_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_m.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_m.hpp new file mode 100644 index 00000000..423548d8 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_m.hpp @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_KERNEL_GROUPED_MATMUL_M_HPP +#define CATLASS_GEMM_KERNEL_GROUPED_MATMUL_M_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +inline __gm__ struct OpSystemRunCfg g_opSystemRunCfg { + Catlass::L2_OFFSET +}; + +namespace Catlass::Gemm::Kernel { + +// Template for grouped matmul kernel. Compute grouped C = A * B +template +class GroupedMatmulSliceM +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using ElementGroupList = ElementGroupList_; + + using BlockScheduler = BlockScheduler_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementC *ptrC; + LayoutC layoutC; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, + LayoutA const &layoutA_, GM_ADDR ptrB_, LayoutB const &layoutB_, GM_ADDR ptrC_, LayoutC const &layoutC_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrC(reinterpret_cast<__gm__ ElementC *>(ptrC_)), + layoutC(layoutC_) + {} + }; + struct Arguments { + GemmCoord problemShape; + uint32_t problemCount; + uint8_t *ptrGroupList; + uint8_t *ptrA; + uint8_t *ptrB; + uint8_t *ptrC; + }; + static bool CanImplement(const Arguments &args) + { + return true; + } + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + static Params ToUnderlyingArguments(const Arguments &args, void *workspace) + { + uint32_t m = args.problemShape.m(); + uint32_t n = args.problemShape.n(); + uint32_t k = args.problemShape.k(); + LayoutA layoutA{m, k}; + LayoutB layoutB{k, n}; + LayoutC layoutC{m, n}; + Params params{args.problemShape, args.problemCount, args.ptrGroupList, args.ptrA, layoutA, + args.ptrB, layoutB, args.ptrC, layoutC}; + return params; + } + // Methods + CATLASS_HOST_DEVICE + GroupedMatmulSliceM() {} + // Methods + CATLASS_HOST_DEVICE + ~GroupedMatmulSliceM() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + Arch::Resource resource; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(params.ptrC); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + int64_t gmGroupOffsetC = 0; + 
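// Illustrative note (the concrete numbers are assumed, not taken from this file): ptrGroupList
// is read as a running prefix sum of per-group row counts, so the loop below recovers group g's
// M as groupList[g] - groupList[g - 1], with groupList[0] used directly for the first group.
// A hypothetical groupList = {128, 128, 384, 512} therefore yields currentM = {128, 0, 256, 128},
// and a zero entry contributes no tiles for its group. Per group, the A and C offsets advance
// by currentM * K and currentM * N while the B offset advances by K * N (one weight matrix per
// group), and startCoreIdx is rotated by coreLoops % coreNum so the round-robin block
// assignment continues across groups instead of restarting at core 0. When a group spans only
// one M tile, B is tagged CACHE_MODE_DISABLE, presumably because its tiles see no reuse in L2.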
+ uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + LayoutC layoutC = params.layoutC.GetTileLayout(inGroupProblemShape.GetCoordMN()); + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB + gmGroupOffsetB); + if (CeilDiv(currentM, L1TileShape::M) == 1) { + gmB.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); + } + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx; + if (coreIdx < startCoreIdx) { + startLoopIdx = coreIdx + coreNum - startCoreIdx; + } else { + startLoopIdx = coreIdx - startCoreIdx; + } + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmOffsetB], layoutB, + gmC[gmGroupOffsetC + gmOffsetC], layoutC, actualBlockShape); + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + {} +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_GROUPED_MATMUL_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_m_per_token_dequant.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_m_per_token_dequant.hpp new file mode 100644 index 00000000..18fde55d --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_m_per_token_dequant.hpp @@ -0,0 +1,331 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_HPP +#define CATLASS_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +template +class GroupedMatmulSliceMPerTokenDequant +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using ElementGroupList = ElementGroupList_; + + using BlockScheduler = BlockScheduler_; + + friend class AicFinishSync; + friend class AivWaitSync; + + struct AicFinishSync { + using MatmulKernel = + GroupedMatmulSliceMPerTokenDequant; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlagWithReverse<0x2, PIPE_FIX>(ptr->flagAicFinishStore); + } + + MatmulKernel *ptr; + }; + + struct AivWaitSync { + using MatmulKernel = + GroupedMatmulSliceMPerTokenDequant; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreWaitFlagWithReverse<0x2, PIPE_MTE3>(ptr->flagAicFinishStore); + } + + MatmulKernel *ptr; + }; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementD *ptrD; + LayoutD layoutD; + GM_ADDR ptrWorkspace; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_, + GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, LayoutD layoutD_, GM_ADDR ptrWorkspace_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + 
layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)), + layoutD(layoutD_), + ptrWorkspace(ptrWorkspace_) + {} + }; + + struct Arguments { + // + // Data members + // + GemmCoord problemShape; + uint32_t problemCount; + uint8_t *ptrGroupList; + uint8_t *ptrA; + uint8_t *ptrB; + uint8_t *ptrScale; + uint8_t *ptrPerTokenScale; + uint8_t *ptrD; + }; + static bool CanImplement(const Arguments &args) + { + return true; + } + static size_t GetWorkspaceSize(const Arguments &args) + { + uint32_t m = args.problemShape.m(); + uint32_t n = args.problemShape.n(); + size_t lenD = static_cast(m) * n; + size_t lenWorkspace = lenD; + size_t sizeWorkspace = lenWorkspace * sizeof(ElementGroupList); + return sizeWorkspace; + } + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + uint32_t m = args.problemShape.m(); + uint32_t n = args.problemShape.n(); + uint32_t k = args.problemShape.k(); + + LayoutA layoutA{m, k}; + LayoutB layoutB{k, n}; + LayoutScale layoutScale{n}; + LayoutPerTokenScale layoutPerTokenScale{m}; + LayoutD layoutD{m, n}; + + Params params{args.problemShape, args.problemCount, args.ptrGroupList, args.ptrA, layoutA, + args.ptrB, layoutB, args.ptrScale, layoutScale, args.ptrPerTokenScale, + layoutPerTokenScale, args.ptrD, layoutD, workspace}; + return params; + } + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceMPerTokenDequant() {} + + CATLASS_DEVICE + ~GroupedMatmulSliceMPerTokenDequant() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + int64_t gmGroupOffsetC = 0; + + AicFinishSync aicFinishSync{this}; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + LayoutC layoutC = LayoutC(inGroupProblemShape.m(), inGroupProblemShape.n()); + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? 
(coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmGroupOffsetC + gmOffsetC], layoutC, actualBlockShape, MakeCallback(&aicFinishSync)); + } else { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmGroupOffsetC + gmOffsetC], layoutC, actualBlockShape); + aicFinishSync(); + } + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource); + + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetC = 0; + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + AivWaitSync aicFinishSync{this}; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutC layoutC = LayoutC(inGroupProblemShape.m(), inGroupProblemShape.n()); + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN()); + + EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, + layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + params.ptrD + gmGroupOffsetD, + layoutD}; + + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? 
(coreIdx + coreNum) : coreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + int64_t gmInGroupOffsetC = layoutC.GetOffset(blockCoordMNK.GetCoordMN() * blockShapeMNK.GetCoordMN()); + auto gmBlockC = gmC[gmGroupOffsetC + gmInGroupOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC, + MakeCallback(&aicFinishSync)); + } + + gmGroupOffsetC += inGroupProblemShape.m() * inGroupProblemShape.n(); + gmGroupOffsetScale += inGroupProblemShape.n(); + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + AscendC::PipeBarrier(); + } + +private: + static constexpr Arch::FlagID FLAG_AIC_FINISH_STORE = 0; + static constexpr Arch::FlagID RV_FLAG_AIC_FINISH_STORE = 1; + Arch::CrossCoreFlagWithReverse<> flagAicFinishStore{FLAG_AIC_FINISH_STORE, RV_FLAG_AIC_FINISH_STORE}; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp new file mode 100644 index 00000000..ead6b4a4 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp @@ -0,0 +1,376 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP +#define CATLASS_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +extern __gm__ struct OpSystemRunCfg g_opSystemRunCfg; + +namespace Catlass::Gemm::Kernel { + +template +class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementD *ptrD; + LayoutD layoutD; + GM_ADDR ptrWorkspace; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_, + GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, LayoutD layoutD_, GM_ADDR ptrWorkspace_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)), + layoutD(layoutD_), + ptrWorkspace(ptrWorkspace_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + uint32_t problemCount; + uint32_t aicCoreNum; + uint8_t *ptrGroupList; + uint8_t *ptrA; + uint8_t *ptrB; + uint8_t *ptrScale; + uint8_t *ptrPerTokenScale; + uint8_t *ptrD; + }; + + static bool CanImplement(const Arguments &args) + { 
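+        // No static feasibility checks are performed for this kernel; every argument
+        // combination is accepted and validation is left to the caller.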
+ return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + size_t lenWorkspace = static_cast(L1TileShape::M) * L1TileShape::N * args.aicCoreNum * WORKSPACE_STAGES; + size_t sizeWorkspace = lenWorkspace * sizeof(uint32_t); + return sizeWorkspace; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + uint32_t m = args.problemShape.m(); + uint32_t n = args.problemShape.n(); + uint32_t k = args.problemShape.k(); + LayoutA layoutA{m, k}; + LayoutB layoutB{k, n}; + LayoutScale layoutScale{n}; + LayoutPerTokenScale layoutPerTokenScale{m}; + LayoutD layoutD{m, n}; + Params params{args.problemShape, args.problemCount, args.ptrGroupList, args.ptrA, layoutA, + args.ptrB, layoutB, args.ptrScale, layoutScale, args.ptrPerTokenScale, + layoutPerTokenScale, args.ptrD, layoutD, workspace}; + return params; + } + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace() + { + Arch::FlagID flagId = 0; + for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { + flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++); + flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); + aicWaitFuncList[stageId] = {this, stageId}; + aicSetFuncList[stageId] = {this, stageId}; + } + } + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + AscendC::ICachePreLoad(1); + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB + gmGroupOffsetB); + if (CeilDiv(currentM, L1TileShape::M) == 1) { + gmB.SetL2CacheHint(AscendC::CacheMode::CACHE_MODE_DISABLE); + } + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? 
(coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmOffsetB], layoutB, gmC[gmOffsetC], + layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmOffsetB], layoutB, gmC[gmOffsetC], + layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); + --stageUsed; + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + AscendC::ICachePreLoad(1); + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource); + + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? 
groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN()); + + EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, + layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + params.ptrD + gmGroupOffsetD, + layoutD}; + + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); + blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]); + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetScale += inGroupProblemShape.n(); + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + AscendC::PipeBarrier(); + } + +private: + friend struct AicWaitFunc; + friend struct AicSetFunc; + + struct AicWaitFunc { + using MatmulKernel = + GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace; + + CATLASS_DEVICE + AicWaitFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + struct AicSetFunc { + using MatmulKernel = + GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace; + + CATLASS_DEVICE + AicSetFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; + Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; + + AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES]; + AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/matmul_activation.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/matmul_activation.hpp new file mode 100644 index 00000000..c8b56d86 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/matmul_activation.hpp @@ -0,0 +1,208 
@@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_MATMUL_EPILOGUE_HPP +#define CATLASS_GEMM_KERNEL_MATMUL_EPILOGUE_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +// Template for matmul add kernel. Compute C(fp32) = A * B, D = Cast(Activation(C)) +template +class MatmulActivation +{ +public: + // BlockMmad的内核 + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + + // 后处理的内核 + using BlockEpilogue = BlockEpilogue_; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using BlockScheduler = BlockScheduler_; + + static_assert(std::is_same_v && + std::is_same_v, + "The CType of Mmad and Epilogue should be consistent."); + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrWorkspace; + EpilogueParams epilogueParams; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA const &layoutA_, GM_ADDR ptrB_, + LayoutB const &layoutB_, GM_ADDR ptrWorkspace_, EpilogueParams const &epilogueParams_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrWorkspace(ptrWorkspace_), + epilogueParams(epilogueParams_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + size_t elementSize; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR ptrC; + GM_ADDR ptrD; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return args.elementSize * args.problemShape.m() * args.problemShape.n(); + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + GemmCoord problemShape = args.problemShape; + uint32_t m = problemShape.m(); + uint32_t n = problemShape.n(); + uint32_t k = problemShape.k(); + LayoutA layoutA{m, k}; + LayoutB layoutB{k, n}; + LayoutC layoutC{m, n}; + // 传出 + typename BlockEpilogue::Params epilogueParams{args.ptrC, layoutC, args.ptrD, layoutC}; + Params params{problemShape, args.ptrA, layoutA, // A矩阵 + args.ptrB, layoutB, // B矩阵 + args.ptrC, epilogueParams}; + return params; + } + + // Methods + CATLASS_DEVICE + MatmulActivation() {} + + template + 
CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrWorkspace); + layout::RowMajor layoutC(params.problemShape.m(), params.problemShape.n()); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + int64_t gmOffsetB = params.layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC], layoutC, + actualBlockShape); + + Arch::CrossCoreSetFlagWithReverse<0x2, PIPE_FIX>(flagAicFinishStore); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + BlockEpilogue blockEpilogue(resource, params.epilogueParams); + + // Represent the full gm + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrWorkspace); + layout::RowMajor layoutC(params.problemShape.m(), params.problemShape.n()); + + // Get aicore information + uint32_t aicoreIndex = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t aicoreNum = AscendC::GetBlockNum(); + uint32_t subcoreIndex = AscendC::GetSubBlockIdx(); + + // Loop through the epilogue calculations of each basic block + GemmCoord blockShape = L1TileShape::ToCoord(); + for (uint32_t loopIdx = aicoreIndex; loopIdx < coreLoops; loopIdx += aicoreNum) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); + // Get the data and layout of C under the current basic block + auto gmBlockC = gmC[layoutC.GetOffset(blockCoord.GetCoordMN() * blockShape.GetCoordMN())]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShape.GetCoordMN()); + // Synchronize cross core + Arch::CrossCoreWaitFlagWithReverse<0x2, PIPE_MTE3>(flagAicFinishStore); + // Actual calculatioin logic for performing block-scoped epilogue + blockEpilogue(blockShape, blockCoord, actualBlockShape, gmBlockC, layoutBlockC); + } + + AscendC::PipeBarrier(); + } + +private: + // ID used for inter-core synchronization + static constexpr Arch::FlagID FLAG_AIC_FINISH_STORE = 0; + static constexpr Arch::FlagID RV_FLAG_AIC_FINISH_STORE = 1; + 
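+    // flagAicFinishStore carries the AIC -> AIV handshake: the AIC path sets it (with reverse)
+    // after storing each block's accumulator to the workspace, and the AIV path waits on it
+    // before running the block-scoped epilogue on that block's data.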
Arch::CrossCoreFlagWithReverse<> flagAicFinishStore{FLAG_AIC_FINISH_STORE, RV_FLAG_AIC_FINISH_STORE}; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_MATMUL_EPILOGUE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/matmul_bias.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/matmul_bias.hpp new file mode 100644 index 00000000..10b7ec97 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/matmul_bias.hpp @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_MATMUL_BIAS_HPP +#define CATLASS_GEMM_KERNEL_MATMUL_BIAS_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +// Template for matmul add kernel. Compute D = A * B + X +template +class MatmulBias +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementBias = typename BlockMmad::ElementBias; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockScheduler = BlockScheduler_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrC; + LayoutC layoutC; + GM_ADDR ptrBias; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA const &layoutA_, GM_ADDR ptrB_, + LayoutB const &layoutB_, GM_ADDR ptrC_, LayoutC const &layoutC_, GM_ADDR ptrBias_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrC(ptrC_), + layoutC(layoutC_), + ptrBias(ptrBias_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR ptrC; + GM_ADDR ptrBias; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + LayoutA layoutA{args.problemShape.m(), args.problemShape.k()}; + LayoutB layoutB{args.problemShape.k(), args.problemShape.n()}; + LayoutC layoutC{args.problemShape.m(), args.problemShape.n()}; + Params params{args.problemShape, args.ptrA, layoutA, args.ptrB, layoutB, args.ptrC, layoutC, args.ptrBias}; + return params; + } + + // Methods + CATLASS_DEVICE + MatmulBias() {} + + template + 
CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + Arch::Resource resource; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); + AscendC::GlobalTensor gmBias; + gmBias.SetGlobalBuffer((__gm__ ElementBias *)params.ptrBias); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + int64_t gmOffsetB = params.layoutB.GetOffset(offsetB); + int64_t gmOffsetC = params.layoutC.GetOffset(offsetC); + int64_t gmOffsetBias = blockCoord.n() * L1TileShape::N; + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC], params.layoutC, + gmBias[gmOffsetBias], actualBlockShape); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + {} +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_MATMUL_BIAS_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/matmul_epilogue.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/matmul_epilogue.hpp new file mode 100644 index 00000000..cb96dcaa --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/matmul_epilogue.hpp @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_MATMUL_EPILOGUE_HPP +#define CATLASS_GEMM_KERNEL_MATMUL_EPILOGUE_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +// Template for matmul add kernel. 
Compute D = A * B + X +template +class MatmulEpilogue +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + + using BlockEpilogue = BlockEpilogue_; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using BlockScheduler = BlockScheduler_; + + static_assert(std::is_same_v && + std::is_same_v, + "The CType of Mmad and Epilogue should be consistent."); + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrWorkspace; + EpilogueParams epilogueParams; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA const &layoutA_, GM_ADDR ptrB_, + LayoutB const &layoutB_, GM_ADDR ptrWorkspace_, EpilogueParams const &epilogueParams_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrWorkspace(ptrWorkspace_), + epilogueParams(epilogueParams_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + size_t elementSize; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR ptrC; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return args.elementSize * args.problemShape.m() * args.problemShape.n(); + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + GemmCoord problemShape = args.problemShape; + uint32_t m = problemShape.m(); + uint32_t n = problemShape.n(); + uint32_t k = problemShape.k(); + LayoutA layoutA{m, k}; + LayoutB layoutB{k, n}; + LayoutC layoutC{m, n}; + typename BlockEpilogue::Params epilogueParams{args.ptrC, layoutC, args.ptrC, layoutC}; + Params params{problemShape, args.ptrA, layoutA, args.ptrB, layoutB, workspace, epilogueParams}; + return params; + } + + // Methods + CATLASS_DEVICE + MatmulEpilogue() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrWorkspace); + layout::RowMajor layoutC(params.problemShape.m(), params.problemShape.n()); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + 
MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + int64_t gmOffsetB = params.layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC], layoutC, + actualBlockShape); + + Arch::CrossCoreSetFlagWithReverse<0x2, PIPE_FIX>(flagAicFinishStore); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + BlockEpilogue blockEpilogue(resource, params.epilogueParams); + + // Represent the full gm + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrWorkspace); + layout::RowMajor layoutC(params.problemShape.m(), params.problemShape.n()); + + // Get aicore information + uint32_t aicoreIndex = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t aicoreNum = AscendC::GetBlockNum(); + uint32_t subcoreIndex = AscendC::GetSubBlockIdx(); + + // Loop through the epilogue calculations of each basic block + GemmCoord blockShape = L1TileShape::ToCoord(); + for (uint32_t loopIdx = aicoreIndex; loopIdx < coreLoops; loopIdx += aicoreNum) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); + // Get the data and layout of C under the current basic block + auto gmBlockC = gmC[layoutC.GetOffset(blockCoord.GetCoordMN() * blockShape.GetCoordMN())]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShape.GetCoordMN()); + // Synchronize cross core + Arch::CrossCoreWaitFlagWithReverse<0x2, PIPE_MTE3>(flagAicFinishStore); + // Actual calculatioin logic for performing block-scoped epilogue + blockEpilogue(blockShape, blockCoord, actualBlockShape, gmBlockC, layoutBlockC); + } + + AscendC::PipeBarrier(); + } + +private: + // ID used for inter-core synchronization + static constexpr Arch::FlagID FLAG_AIC_FINISH_STORE = 0; + static constexpr Arch::FlagID RV_FLAG_AIC_FINISH_STORE = 1; + Arch::CrossCoreFlagWithReverse<> flagAicFinishStore{FLAG_AIC_FINISH_STORE, RV_FLAG_AIC_FINISH_STORE}; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_MATMUL_EPILOGUE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/matmul_full_loadA.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/matmul_full_loadA.hpp new file mode 100644 index 00000000..c4477ba2 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/matmul_full_loadA.hpp @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_MATMUL_FULL_LOADA_HPP +#define CATLASS_GEMM_KERNEL_MATMUL_FULL_LOADA_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +// Template for Matmul kernel. Compute C = A * B +template +class MatmulFullLoadA +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockScheduler = BlockScheduler_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrC; + LayoutC layoutC; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrC(ptrC_), + layoutC(layoutC_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR ptrC; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + LayoutA layoutA{args.problemShape.m(), args.problemShape.k()}; + LayoutB layoutB{args.problemShape.k(), args.problemShape.n()}; + LayoutC layoutC{args.problemShape.m(), args.problemShape.n()}; + Params params{args.problemShape, args.ptrA, layoutA, args.ptrB, layoutB, args.ptrC, layoutC}; + return params; + } + + // Methods + CATLASS_DEVICE + MatmulFullLoadA() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + /// Executes one Matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + Arch::Resource resource; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); + + int64_t gmOffsetAPreload{0}; + uint32_t firstBlockIdx = AscendC::GetBlockIdx(); + + for (uint32_t loopIdx = firstBlockIdx; loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, 
blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + int64_t gmOffsetB = params.layoutB.GetOffset(offsetB); + int64_t gmOffsetC = params.layoutC.GetOffset(offsetC); + + // Judge whether the current blockA is already on L1Cache + bool isFirstBlock = (loopIdx == firstBlockIdx); + bool needLoadL1 = true; + if (isFirstBlock) { + gmOffsetAPreload = gmOffsetA; + } else { + if (gmOffsetA == gmOffsetAPreload) { + needLoadL1 = false; + } else { + gmOffsetAPreload = gmOffsetA; + } + } + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC], params.layoutC, + actualBlockShape, needLoadL1); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + {} +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_MATMUL_FULL_LOADA_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/optimized_matmul.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/optimized_matmul.hpp new file mode 100644 index 00000000..f51f2531 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/optimized_matmul.hpp @@ -0,0 +1,397 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_KERNEL_OPTIMIZED_MATMUL_HPP +#define CATLASS_GEMM_KERNEL_OPTIMIZED_MATMUL_HPP + +#include "catlass/catlass.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/epilogue/tile/copy_gm_to_ub.hpp" +#include "catlass/epilogue/tile/copy_ub_to_gm.hpp" +#include "catlass/gemm/kernel/padding_matmul.hpp" + +namespace Catlass::Gemm::Kernel { + +template +class OptimizedMatmul +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using ElementA = typename BlockMmad::ElementA; + using ElementB = typename BlockMmad::ElementB; + using LayoutWA = typename BlockMmad::LayoutA; + using LayoutWB = typename BlockMmad::LayoutB; + + template + struct LayoutHelper { + using type = typename T::LayoutIn; + }; + template <> + struct LayoutHelper { + using type = void; + }; + + using LayoutA = std::conditional_t, typename BlockMmad::LayoutA, + typename LayoutHelper::type>; + using LayoutB = std::conditional_t, typename BlockMmad::LayoutB, + typename LayoutHelper::type>; + + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + + using BlockScheduler = BlockScheduler_; + + /// Parameters structure + struct ParamsBase { + // Data members + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrC; + LayoutC layoutC; + + // Methods + CATLASS_HOST_DEVICE + ParamsBase() {} + + CATLASS_HOST_DEVICE + ParamsBase(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrC(ptrC_), + layoutC(layoutC_) + {} + }; + + template + struct KernelParams : public ParamsBase { + // Data members + using LayoutWA = typename BlockMmad::LayoutA; + using LayoutWB = typename BlockMmad::LayoutB; + + GM_ADDR ptrWA; + LayoutWA layoutWA; + GM_ADDR ptrWB; + LayoutWB layoutWB; + + // Methods + CATLASS_HOST_DEVICE + KernelParams() {} + + CATLASS_HOST_DEVICE + KernelParams(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_, GM_ADDR ptrWA_, LayoutWA layoutWA_, GM_ADDR ptrWB_, + LayoutWB layoutWB_) + : ParamsBase(problemShape_, ptrA_, layoutA_, ptrB_, layoutB_, ptrC_, layoutC_), + ptrWA(ptrWA_), + layoutWA(layoutWA_), + ptrWB(ptrWB_), + layoutWB(layoutWB_) + {} + }; + + template <> + struct KernelParams : public ParamsBase { + // Data members + using LayoutWA = typename BlockMmad::LayoutA; + + GM_ADDR ptrWA; + LayoutWA layoutWA; + + // Methods + CATLASS_HOST_DEVICE + KernelParams() {} + + CATLASS_HOST_DEVICE + KernelParams(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_, GM_ADDR ptrWA_, LayoutWA layoutWA_) + : ParamsBase(problemShape_, ptrA_, layoutA_, ptrB_, layoutB_, ptrC_, layoutC_), + ptrWA(ptrWA_), + layoutWA(layoutWA_) + {} + }; + + template <> + struct KernelParams : public ParamsBase { + // Data members + using LayoutWB = typename BlockMmad::LayoutB; + + GM_ADDR ptrWB; + LayoutWB layoutWB; + ; + + // Methods + CATLASS_HOST_DEVICE + KernelParams() {} + + CATLASS_HOST_DEVICE + KernelParams(GemmCoord const &problemShape_, GM_ADDR ptrA_, 
LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_, GM_ADDR ptrWB_, LayoutWB layoutWB_) + : ParamsBase(problemShape_, ptrA_, layoutA_, ptrB_, layoutB_, ptrC_, layoutC_), + ptrWB(ptrWB_), + layoutWB(layoutWB_) + {} + }; + + template <> + struct KernelParams : public ParamsBase { + // Methods + CATLASS_HOST_DEVICE + KernelParams() {} + + CATLASS_HOST_DEVICE + KernelParams(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_) + : ParamsBase(problemShape_, ptrA_, layoutA_, ptrB_, layoutB_, ptrC_, layoutC_) + {} + }; + + using Params = KernelParams, !std::is_void_v>; + + struct Arguments { + GemmCoord problemShape; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR ptrC; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + constexpr bool isPaddingA = !std::is_void_v; + constexpr bool isPaddingB = !std::is_void_v; + size_t workspaceSize = 0; + if constexpr (isPaddingA) { + if constexpr (PrologueA::paddingTag == PaddingTag::PADDING_BLOCK_ND) { + workspaceSize += PrologueA::GetWorkspaceSize(args.problemShape.m(), args.problemShape.k(), + L1TileShape::M, L1TileShape::K); + } else if constexpr (PrologueA::paddingTag == PaddingTag::PADDING_ND) { + // Optimal bandwidth for 512 Byte aligned reads + workspaceSize += + PrologueA::GetWorkspaceSize(args.problemShape.m(), args.problemShape.k(), 512 / sizeof(ElementA)); + } + } + if constexpr (isPaddingB) { + if constexpr (PrologueB::paddingTag == PaddingTag::PADDING_BLOCK_ND) { + workspaceSize += PrologueB::GetWorkspaceSize(args.problemShape.k(), args.problemShape.n(), + L1TileShape::K, L1TileShape::N); + } else if constexpr (PrologueB::paddingTag == PaddingTag::PADDING_ND) { + // Optimal bandwidth for 512 Byte aligned reads + workspaceSize += + PrologueB::GetWorkspaceSize(args.problemShape.k(), args.problemShape.n(), 512 / sizeof(ElementB)); + } + } + return workspaceSize; + } + + static auto ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + constexpr bool isPaddingA = !std::is_void_v; + constexpr bool isPaddingB = !std::is_void_v; + LayoutA layoutA = LayoutA::template MakeLayout(args.problemShape.m(), args.problemShape.k()); + LayoutB layoutB = LayoutB::template MakeLayout(args.problemShape.k(), args.problemShape.n()); + LayoutC layoutC = LayoutC::template MakeLayout(args.problemShape.m(), args.problemShape.n()); + + uint8_t *gmWA = nullptr; + uint8_t *gmWB = nullptr; + size_t sizeWA = 0; + if constexpr (isPaddingA) { + gmWA = workspace; + if constexpr (PrologueA::paddingTag == PaddingTag::PADDING_BLOCK_ND) { + sizeWA += PrologueA::GetWorkspaceSize(args.problemShape.m(), args.problemShape.k(), L1TileShape::M, + L1TileShape::K); + } else if constexpr (PrologueA::paddingTag == PaddingTag::PADDING_ND) { + // Optimal bandwidth for 512 Byte aligned reads + sizeWA += + PrologueA::GetWorkspaceSize(args.problemShape.m(), args.problemShape.k(), 512 / sizeof(ElementA)); + } + } + if constexpr (isPaddingB) { + gmWB = workspace + sizeWA; + } + + if constexpr (isPaddingA && isPaddingB) { + typename PrologueA::LayoutOut layoutWA; + typename PrologueB::LayoutOut layoutWB; + if constexpr (PrologueA::paddingTag == PaddingTag::PADDING_BLOCK_ND) { + layoutWA = PrologueA::GetWorkspaceLayout(layoutA, L1TileShape::M, L1TileShape::K); + } else if constexpr (PrologueA::paddingTag == PaddingTag::PADDING_ND) { + layoutWA = PrologueA::GetWorkspaceLayout(layoutA, 512 / 
sizeof(ElementA)); + } + if constexpr (PrologueB::paddingTag == PaddingTag::PADDING_BLOCK_ND) { + layoutWB = PrologueB::GetWorkspaceLayout(layoutB, L1TileShape::K, L1TileShape::N); + } else if constexpr (PrologueB::paddingTag == PaddingTag::PADDING_ND) { + // Optimal bandwidth for 512 Byte aligned reads + layoutWB = PrologueB::GetWorkspaceLayout(layoutB, 512 / sizeof(ElementB)); + } + Params params{args.problemShape, args.ptrA, layoutA, args.ptrB, layoutB, args.ptrC, + layoutC, gmWA, layoutWA, gmWB, layoutWB}; + return params; + } else if constexpr (isPaddingA) { + typename PrologueA::LayoutOut layoutWA; + if constexpr (PrologueA::paddingTag == PaddingTag::PADDING_BLOCK_ND) { + layoutWA = PrologueA::GetWorkspaceLayout(layoutA, L1TileShape::M, L1TileShape::K); + } else if constexpr (PrologueA::paddingTag == PaddingTag::PADDING_ND) { + // Optimal bandwidth for 512 Byte aligned reads + layoutWA = PrologueA::GetWorkspaceLayout(layoutA, 512 / sizeof(ElementA)); + } + Params params{args.problemShape, args.ptrA, layoutA, args.ptrB, layoutB, + args.ptrC, layoutC, gmWA, layoutWA}; + return params; + } else if constexpr (isPaddingB) { + typename PrologueB::LayoutOut layoutWB; + if constexpr (PrologueB::paddingTag == PaddingTag::PADDING_BLOCK_ND) { + layoutWB = PrologueB::GetWorkspaceLayout(layoutB, L1TileShape::K, L1TileShape::N); + } else if constexpr (PrologueB::paddingTag == PaddingTag::PADDING_ND) { + // Optimal bandwidth for 512 Byte aligned reads + layoutWB = PrologueB::GetWorkspaceLayout(layoutB, 512 / sizeof(ElementB)); + } + Params params{args.problemShape, args.ptrA, layoutA, args.ptrB, layoutB, + args.ptrC, layoutC, gmWB, layoutWB}; + return params; + } else { + Params params{args.problemShape, args.ptrA, layoutA, args.ptrB, layoutB, args.ptrC, layoutC}; + return params; + } + } + + // Methods + CATLASS_DEVICE + OptimizedMatmul() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + if constexpr (!std::is_void_v) { + AscendC::GlobalTensor gmA; + AscendC::GlobalTensor gmWA; + gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrA)); + gmWA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrWA)); + PrologueA prologueA(resource); + prologueA(gmWA, gmA, params.layoutWA, params.layoutA); + } + + if constexpr (!std::is_void_v) { + AscendC::GlobalTensor gmB; + AscendC::GlobalTensor gmWB; + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(params.ptrB)); + gmWB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(params.ptrWB)); + PrologueB prologueB(resource); + prologueB(gmWB, gmB, params.layoutWB, params.layoutB); + // 0x0 synchronization control between AI Core + } + if constexpr (!std::is_void_v || !std::is_void_v) { + Catlass::Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishPadding); + } + + AscendC::PipeBarrier(); + } + + /// Executes matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + if constexpr (!std::is_void_v || !std::is_void_v) { + Catlass::Arch::CrossCoreWaitFlag(flagAivFinishPadding); + } + + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + typename BlockMmad::LayoutA layoutA; + typename BlockMmad::LayoutB layoutB; + + // Represent the full gm + AscendC::GlobalTensor gmA; + if constexpr (std::is_void_v) { + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + layoutA = 
params.layoutA; + } else { + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrWA); + layoutA = params.layoutWA; + } + AscendC::GlobalTensor gmB; + if constexpr (std::is_void_v) { + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB); + layoutB = params.layoutB; + } else { + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrWB); + layoutB = params.layoutWB; + } + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); + + BlockMmad blockMmad(resource); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = params.layoutC.GetOffset(offsetC); + + bool isFirstBlock = (loopIdx == AscendC::GetBlockIdx()); + bool hasNextBlock = false; + GemmCoord nextBlockIdCoord; + GemmCoord nextActualBlockShape; + if (loopIdx + AscendC::GetBlockNum() < coreLoops) { + hasNextBlock = true; + nextBlockIdCoord = matmulBlockScheduler.GetBlockCoord(loopIdx + AscendC::GetBlockNum()); + nextActualBlockShape = matmulBlockScheduler.GetActualBlockShape(nextBlockIdCoord); + } + MatrixCoord offsetNextA{nextBlockIdCoord.m() * L1TileShape::M, nextBlockIdCoord.k() * L1TileShape::K}; + MatrixCoord offsetNextB{nextBlockIdCoord.k() * L1TileShape::K, nextBlockIdCoord.n() * L1TileShape::N}; + int64_t gmOffsetNextA = layoutA.GetOffset(offsetNextA); + int64_t gmOffsetNextB = layoutB.GetOffset(offsetNextB); + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[gmOffsetA], layoutA, gmB[gmOffsetB], layoutB, gmC[gmOffsetC], params.layoutC, + gmA[gmOffsetNextA], gmB[gmOffsetNextB], actualBlockShape, nextActualBlockShape, isFirstBlock, + hasNextBlock); + } + + AscendC::PipeBarrier(); + } + +private: + static constexpr Arch::FlagID FLAG_AIV_FINISH_STORE = 0; + Arch::CrossCoreFlag flagAivFinishPadding{FLAG_AIV_FINISH_STORE}; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_OPTIMIZED_MATMUL_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/optimized_matmul_tla.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/optimized_matmul_tla.hpp new file mode 100644 index 00000000..9f224b51 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/optimized_matmul_tla.hpp @@ -0,0 +1,421 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_KERNEL_OPTIMIZED_MATMUL_TLA_HPP +#define CATLASS_GEMM_KERNEL_OPTIMIZED_MATMUL_TLA_HPP + +#include "catlass/catlass.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/gemm/tile/tile_copy_tla.hpp" +#include "tla/layout.hpp" +#include "tla/tensor.hpp" + +namespace Catlass::Gemm::Kernel { + +template +struct PaddingMatrixBlockND { +public: + using ArchTag = ArchTag_; + using TensorIn = TensorIn_; + using TensorOut = TensorOut_; + using Element = typename TensorIn::Element; + using LayoutIn = typename TensorIn::Layout; + using LayoutOut = typename TensorOut::Layout; + + using LayoutInner = tla::Layout, tla::Stride>>; + using TensorInnerUb = tla::Tensor, LayoutInner, tla::Coord, + AscendC::TPosition::VECCALC>; + using TensorInnerSrcGm = + tla::Tensor, LayoutInner, tla::Coord, AscendC::TPosition::GM>; + + using LayoutInnerDstGm = tla::Layout, tla::Shape>, + tla::Stride, tla::Stride, int64_t>>>; + using TensorInnerDstGm = tla::Tensor, LayoutInnerDstGm, tla::Coord, + AscendC::TPosition::GM>; + + using CopyGm2Ub = Catlass::Gemm::Tile::TileCopyTla; + using CopyUb2Gm = Catlass::Gemm::Tile::TileCopyTlaExt; + + CopyGm2Ub copyGm2Ub; + CopyUb2Gm copyUb2Gm; + + CATLASS_DEVICE + PaddingMatrixBlockND(Arch::Resource &resource) + { + int64_t bufferOffset = 0; + for (uint32_t i = 0; i < BUFFER_NUM; i++) { + // Allocate space on UB + inputBuffer[i] = resource.ubBuf.template GetBufferByByte(bufferOffset * sizeof(Element)); + // Split the available UB space evenly across the BUFFER_NUM buffers + bufferOffset += COMPUTE_LENGTH; + } + } + + template + CATLASS_DEVICE auto GetPaddingTensorSrc(Tensor const &tensor) + { + if constexpr (std::is_same_v) { + return tensor; + } else { + auto shape = tla::MakeShape(tla::get<1>(tensor.shape()), tla::get<0>(tensor.shape())); + auto stride = tla::MakeStride(tla::get<1>(tensor.stride()), tla::get<0>(tensor.stride())); + return tla::MakeTensor(tensor.data(), MakeLayout(shape, stride), Arch::PositionGM{}); + } + } + + template + CATLASS_DEVICE auto GetPaddingTensorDst(Tensor const &tensor) + { + if constexpr (std::is_same_v) { + return tensor; + } else { + auto shape = tla::MakeShape(tla::get<1>(tensor.shape()), tla::get<0>(tensor.shape())); + auto stride = tla::MakeStride(tla::get<1>(tensor.stride()), tla::get<0>(tensor.stride())); + return tla::MakeTensor(tensor.data(), MakeLayout(shape, stride), Arch::PositionGM{}); + } + } + + template + CATLASS_DEVICE void operator()(TensorDst &tensorDst, TensorSrc const &tensorSrc) + { + auto paddingTensorSrc = GetPaddingTensorSrc(tensorSrc); + auto paddingTensorDst = GetPaddingTensorDst(tensorDst); + + uint32_t aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum(); + uint32_t aivId = AscendC::GetBlockIdx(); + + // Tile by rows: each row is one tile + uint32_t tilesNum = tla::get<0>(paddingTensorSrc.shape()); + uint32_t tileLen = tla::get<1>(paddingTensorSrc.shape()); + uint32_t roundTileLen = RoundUp(tla::get<1>(paddingTensorSrc.shape())); + // Compute each AIV's workload; leftover tiles are assigned to the leading AIVs first + uint32_t tilesPerAiv = tilesNum / aivNum; + uint32_t tileRemain = tilesNum % aivNum; + if (aivId < tileRemain) { + tilesPerAiv++; + } + // Because the remainder was redistributed above, later AIVs shift their starting offset accordingly + uint32_t mIdx = aivId * tilesPerAiv; + if (aivId >= tileRemain) { + mIdx += tileRemain; + } + MatrixCoord blockOffset(mIdx, 0); + // Set up the UB-to-GM synchronization events + AscendC::SetFlag(eventIds[0]); + AscendC::SetFlag(eventIds[1]); + uint32_t coreLoops{0}; + if (roundTileLen > COMPUTE_LENGTH) { + //
Handle the same tile on multiple loops. + uint32_t loopsPerTile = (tileLen + COMPUTE_LENGTH - 1) / COMPUTE_LENGTH; + coreLoops = tilesPerAiv * loopsPerTile; + for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) { + uint32_t tileIdx = loopIdx / loopsPerTile; + uint32_t inTileLoopIdx = loopIdx % loopsPerTile; + auto offset = tla::MakeCoord(mIdx + tileIdx, inTileLoopIdx * COMPUTE_LENGTH); + uint32_t actualDataNum = COMPUTE_LENGTH; + if (tileLen - inTileLoopIdx * COMPUTE_LENGTH < COMPUTE_LENGTH) { + actualDataNum = tileLen - inTileLoopIdx * COMPUTE_LENGTH; + } + + AscendC::WaitFlag(eventIds[bufferIndex]); + auto tensorTileSrc = + GetTile(paddingTensorSrc, offset, tla::MakeShape(static_cast(1), actualDataNum)); + auto tensorTileDst = + GetTile(paddingTensorDst, offset, tla::MakeShape(static_cast(1), actualDataNum)); + + auto layoutDstUb = MakeLayout(tla::MakeShape(static_cast(1), actualDataNum), + tla::MakeStride(static_cast(COMPUTE_LENGTH), tla::Int<1>{})); + auto tensorDstUb = tla::MakeTensor(inputBuffer[bufferIndex], layoutDstUb, Arch::PositionUB{}); + + copyGm2Ub(tensorDstUb, tensorTileSrc); + AscendC::SetFlag(eventIds[bufferIndex]); + AscendC::WaitFlag(eventIds[bufferIndex]); + + auto layoutSrcUb = MakeLayout( + tla::MakeShape(CeilDiv(actualDataNum, tla::get<1, 0>(paddingTensorDst.shape())), + tla::get<1, 0>(paddingTensorDst.shape())), + tla::MakeStride(static_cast(tla::get<1, 0>(paddingTensorDst.shape())), tla::Int<1>{})); + auto tensorSrcUb = tla::MakeTensor(inputBuffer[bufferIndex], layoutSrcUb, Arch::PositionUB{}); + copyUb2Gm(tensorTileDst, tensorSrcUb); + AscendC::SetFlag(eventIds[bufferIndex]); + + bufferIndex = (bufferIndex + 1) % BUFFER_NUM; + } + } else { + // Handle multiple tile each loop. + uint32_t tilesPerLoop = COMPUTE_LENGTH / roundTileLen; + coreLoops = (tilesPerAiv + tilesPerLoop - 1) / tilesPerLoop; + for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) { + uint32_t tileIdx = loopIdx * tilesPerLoop; + uint32_t actualTilesNum = tilesPerLoop; + if (tilesPerAiv - tileIdx < tilesPerLoop) { + actualTilesNum = tilesPerAiv - tileIdx; + } + auto offset = tla::MakeCoord(mIdx + tileIdx, static_cast(0)); + + AscendC::WaitFlag(eventIds[bufferIndex]); + auto tensorTileSrc = GetTile(paddingTensorSrc, offset, tla::MakeShape(actualTilesNum, tileLen)); + + auto layoutDstUb = MakeLayout(tla::MakeShape(actualTilesNum, tileLen), + tla::MakeStride(static_cast(roundTileLen), tla::Int<1>{})); + auto tensorDstUb = tla::MakeTensor(inputBuffer[bufferIndex], layoutDstUb, Arch::PositionUB{}); + + copyGm2Ub(tensorDstUb, tensorTileSrc); + AscendC::SetFlag(eventIds[bufferIndex]); + AscendC::WaitFlag(eventIds[bufferIndex]); + + auto layoutSrcUb = MakeLayout( + tla::MakeShape(CeilDiv(tileLen, tla::get<1, 0>(paddingTensorDst.shape())), + tla::get<1, 0>(paddingTensorDst.shape())), + tla::MakeStride(static_cast(tla::get<1, 0>(paddingTensorDst.shape())), tla::Int<1>{})); + for (uint32_t i = 0; i < actualTilesNum; ++i) { + auto tensorTileDst = + GetTile(paddingTensorDst, tla::MakeCoord(mIdx + tileIdx + i, static_cast(0)), + tla::MakeShape(static_cast(1), tileLen)); + auto tensorSrcUb = + tla::MakeTensor(inputBuffer[bufferIndex][i * roundTileLen], layoutSrcUb, Arch::PositionUB{}); + copyUb2Gm(tensorTileDst, tensorSrcUb); + } + AscendC::SetFlag(eventIds[bufferIndex]); + + bufferIndex = (bufferIndex + 1) % BUFFER_NUM; + } + } + AscendC::WaitFlag(eventIds[0]); + AscendC::WaitFlag(eventIds[1]); + } + + CATLASS_DEVICE + ~PaddingMatrixBlockND() {} + +private: + static const uint32_t BUFFER_NUM = 2; + 
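// Ping-pong staging: while one of the BUFFER_NUM UB buffers is being filled from GM, the other is drained back to GM; the eventIds below serialize reuse of each individual buffer. +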
AscendC::LocalTensor inputBuffer[BUFFER_NUM]; + AscendC::TEventID eventIds[BUFFER_NUM] = {EVENT_ID0, EVENT_ID1}; + uint32_t bufferIndex{0}; + static_assert(BUFFER_NUM * COMPUTE_LENGTH * sizeof(Element) <= ArchTag::UB_SIZE, "Exceeding the UB space!"); +}; + +template +class OptimizedMatmulTla +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using ElementA = typename BlockMmad::ElementA; + using ElementB = typename BlockMmad::ElementB; + using LayoutWA = typename BlockMmad::LayoutA; + using LayoutWB = typename BlockMmad::LayoutB; + + template + struct LayoutHelper { + using type = typename T::LayoutIn; + }; + template <> + struct LayoutHelper { + using type = void; + }; + using LayoutA = std::conditional_t, LayoutWA, typename LayoutHelper::type>; + using LayoutB = std::conditional_t, LayoutWB, typename LayoutHelper::type>; + + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + + using BlockScheduler = BlockScheduler_; + + static constexpr uint32_t L1_TILE_M = tla::get<0>(L1TileShape{}); + static constexpr uint32_t L1_TILE_N = tla::get<1>(L1TileShape{}); + static constexpr uint32_t L1_TILE_K = tla::get<2>(L1TileShape{}); + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrC; + LayoutC layoutC; + GM_ADDR ptrWA; + LayoutWA layoutWA; + GM_ADDR ptrWB; + LayoutWB layoutWB; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_, GM_ADDR ptrWA_, LayoutWA layoutWA_, GM_ADDR ptrWB_, LayoutWB layoutWB_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrC(ptrC_), + layoutC(layoutC_), + ptrWA(ptrWA_), + layoutWA(layoutWA_), + ptrWB(ptrWB_), + layoutWB(layoutWB_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + uint8_t *ptrA; + LayoutA layoutA; + uint8_t *ptrB; + LayoutB layoutB; + uint8_t *ptrC; + LayoutC layoutC; + uint8_t *ptrWA; + LayoutWA layoutWA; + uint8_t *ptrWB; + LayoutWB layoutWB; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return 0; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + Params params{args.problemShape, args.ptrA, args.layoutA, args.ptrB, args.layoutB, args.ptrC, + args.layoutC, args.ptrWA, args.layoutWA, args.ptrWB, args.layoutWB}; + return params; + } + + // Methods + CATLASS_DEVICE + OptimizedMatmulTla() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + if constexpr (!std::is_void_v) { + AscendC::GlobalTensor gmA; + AscendC::GlobalTensor gmWA; + gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrA)); + gmWA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrWA)); + auto tensorA = tla::MakeTensor(gmA, params.layoutA, Arch::PositionGM{}); + auto tensorWA = tla::MakeTensor(gmWA, params.layoutWA, Arch::PositionGM{}); + PaddingA paddingA(resource); + paddingA(tensorWA, tensorA); + } + + if constexpr (!std::is_void_v) { + AscendC::GlobalTensor gmB; + AscendC::GlobalTensor gmWB; + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB 
*>(params.ptrB)); + gmWB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(params.ptrWB)); + auto tensorB = tla::MakeTensor(gmB, params.layoutB, Arch::PositionGM{}); + auto tensorWB = tla::MakeTensor(gmWB, params.layoutWB, Arch::PositionGM{}); + PaddingB paddingB(resource); + paddingB(tensorWB, tensorB); + // 0x0 synchronization control between AI Core + } + if constexpr (!std::is_void_v || !std::is_void_v) { + Catlass::Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishPadding); + } + + AscendC::PipeBarrier(); + } + + /// Executes matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + if (!std::is_void_v || !std::is_void_v) { + Catlass::Arch::CrossCoreWaitFlag(flagAivFinishPadding); + } + + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1_TILE_M, L1_TILE_N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrWA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrWB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); + + auto tensorA = tla::MakeTensor(gmA, params.layoutWA, Arch::PositionGM{}); + auto tensorB = tla::MakeTensor(gmB, params.layoutWB, Arch::PositionGM{}); + auto tensorC = tla::MakeTensor(gmC, params.layoutC, Arch::PositionGM{}); + + BlockMmad blockMmad(resource); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + auto tensorBlockA = GetTile(tensorA, tla::MakeCoord(blockCoord.m() * L1_TILE_M, blockCoord.k() * L1_TILE_K), + tla::MakeShape(actualBlockShape.m(), actualBlockShape.k())); + auto tensorBlockB = GetTile(tensorB, tla::MakeCoord(blockCoord.k() * L1_TILE_K, blockCoord.n() * L1_TILE_N), + tla::MakeShape(actualBlockShape.k(), actualBlockShape.n())); + auto tensorBlockC = GetTile(tensorC, tla::MakeCoord(blockCoord.m() * L1_TILE_M, blockCoord.n() * L1_TILE_N), + tla::MakeShape(actualBlockShape.m(), actualBlockShape.n())); + + bool isFirstBlock = (loopIdx == AscendC::GetBlockIdx()); + bool hasNextBlock = false; + GemmCoord nextBlockCoord; + GemmCoord nextActualBlockShape; + if (loopIdx + AscendC::GetBlockNum() < coreLoops) { + hasNextBlock = true; + nextBlockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx + AscendC::GetBlockNum()); + nextActualBlockShape = matmulBlockScheduler.GetActualBlockShape(nextBlockCoord); + } + + auto nextTensorBlockA = + GetTile(tensorA, tla::MakeCoord(nextBlockCoord.m() * L1_TILE_M, nextBlockCoord.k() * L1_TILE_K), + tla::MakeShape(nextActualBlockShape.m(), nextActualBlockShape.k())); + auto nextTensorBlockB = + GetTile(tensorB, tla::MakeCoord(nextBlockCoord.k() * L1_TILE_K, nextBlockCoord.n() * L1_TILE_N), + tla::MakeShape(nextActualBlockShape.k(), nextActualBlockShape.n())); + + // Compute block-scoped matrix multiply-add + blockMmad(tensorBlockA, tensorBlockB, tensorBlockC, nextTensorBlockA, nextTensorBlockB, actualBlockShape, + nextActualBlockShape, isFirstBlock, hasNextBlock); + } + + AscendC::PipeBarrier(); + } + +private: + static constexpr Arch::FlagID FLAG_AIV_FINISH_STORE = 0; + Arch::CrossCoreFlag flagAivFinishPadding{FLAG_AIV_FINISH_STORE}; + Arch::Resource 
resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_OPTIMIZED_MATMUL_TLA_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/padding_matmul.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/padding_matmul.hpp new file mode 100644 index 00000000..e6c46315 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/padding_matmul.hpp @@ -0,0 +1,594 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_PADDING_MATMUL_HPP +#define CATLASS_GEMM_KERNEL_PADDING_MATMUL_HPP + +#include "catlass/catlass.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/epilogue/tile/copy_gm_to_ub.hpp" +#include "catlass/epilogue/tile/copy_ub_to_gm.hpp" + +namespace Catlass::Gemm::Kernel { + +enum class PaddingTag { NO_PADDING, PADDING_ND, PADDING_BLOCK_ND }; + +template +struct PaddingMatrixBlockND { +public: + using ArchTag = ArchTag_; + using Element = Element_; + using LayoutIn = LayoutIn_; + using LayoutOut = LayoutOut_; + using ComputeLayout = Catlass::layout::RowMajor; + using ComputeLayoutDst = Catlass::layout::PaddingRowMajor; + using CopyGm2Ub = Catlass::Epilogue::Tile::CopyGm2Ub>; + using CopyUb2Gm = Catlass::Epilogue::Tile::CopyUb2Gm>; + + static const PaddingTag paddingTag = PaddingTag::PADDING_BLOCK_ND; + CATLASS_HOST_DEVICE static LayoutOut GetWorkspaceLayout(LayoutIn &layout, uint32_t rowAlign, uint32_t colAlign) + { + return LayoutOut(layout.shape(0), layout.shape(1), rowAlign, colAlign); + } + static size_t GetWorkspaceSize(uint32_t rows, uint32_t cols, uint32_t rowAlign, uint32_t colAlign) + { + return static_cast(RoundUp(rows, rowAlign)) * RoundUp(cols, colAlign) * sizeof(Element); + } + + CopyGm2Ub copyGm2Ub; + CopyUb2Gm copyUb2Gm; + + CATLASS_DEVICE + PaddingMatrixBlockND(Arch::Resource &resource) + { + int64_t bufferOffset = 0; + for (uint32_t i = 0; i < BUFFER_NUM; i++) { + inputBuffer[i] = resource.ubBuf.template GetBufferByByte(bufferOffset * sizeof(Element)); + bufferOffset += COMPUTE_LENGTH; + } + } + + CATLASS_DEVICE + ComputeLayout GetPaddingComputeLayout(layout::RowMajor const &layout) + { + return ComputeLayout(layout.shape(0), layout.shape(1), layout.stride(0)); + } + + CATLASS_DEVICE + ComputeLayout GetPaddingComputeLayout(layout::ColumnMajor const &layout) + { + return ComputeLayout(layout.shape(1), layout.shape(0), layout.stride(1)); + } + + CATLASS_DEVICE + ComputeLayoutDst GetPaddingComputeLayout(layout::PaddingRowMajor const &layout) + { + return ComputeLayoutDst(layout.shape(0) * layout.shape(1), layout.shape(2) * layout.shape(3), layout.shape(0), + layout.shape(2)); + } + + CATLASS_DEVICE + ComputeLayoutDst GetPaddingComputeLayout(layout::PaddingColumnMajor const &layout) + { + return ComputeLayoutDst(layout.shape(2) 
* layout.shape(3), layout.shape(0) * layout.shape(1), layout.shape(2), + layout.shape(0)); + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::GlobalTensor const &src, + LayoutOut layoutDst, LayoutIn layoutSrc) + { + auto computeLayoutSrc = GetPaddingComputeLayout(layoutSrc); + auto computeLayoutDst = GetPaddingComputeLayout(layoutDst); + + uint32_t aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum(); + uint32_t aivId = AscendC::GetBlockIdx(); + + // Each line is a tile. + uint32_t tilesNum = computeLayoutSrc.shape(0); + uint32_t tileLen = computeLayoutSrc.shape(1); + uint32_t roundTileLen = RoundUp(computeLayoutSrc.shape(1)); + + uint32_t tilesPerAiv = tilesNum / aivNum; + uint32_t tileRemain = tilesNum % aivNum; + if (aivId < tileRemain) { + tilesPerAiv++; + } + uint32_t mIdx = aivId * tilesPerAiv; + if (aivId >= tileRemain) { + mIdx += tileRemain; + } + MatrixCoord blockOffset(mIdx, 0); + + AscendC::SetFlag(eventIds[0]); + AscendC::SetFlag(eventIds[1]); + uint32_t coreLoops{0}; + if (roundTileLen > COMPUTE_LENGTH) { + // Handle the same tile on multiple loops. + uint32_t loopsPerTile = (tileLen + COMPUTE_LENGTH - 1) / COMPUTE_LENGTH; + coreLoops = tilesPerAiv * loopsPerTile; + for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) { + uint32_t tileIdx = loopIdx / loopsPerTile; + uint32_t inTileLoopIdx = loopIdx % loopsPerTile; + MatrixCoord loopOffset(tileIdx, inTileLoopIdx * COMPUTE_LENGTH); + uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + loopOffset); + uint32_t actualDataNum = COMPUTE_LENGTH; + if (tileLen - inTileLoopIdx * COMPUTE_LENGTH < COMPUTE_LENGTH) { + actualDataNum = tileLen - inTileLoopIdx * COMPUTE_LENGTH; + } + + AscendC::WaitFlag(eventIds[bufferIndex]); + ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(1, actualDataNum)); + ComputeLayout ubLayout = ComputeLayout{1, actualDataNum}; + ComputeLayout dstLayout = ComputeLayout(CeilDiv(actualDataNum, computeLayoutDst.shape(2)), + computeLayoutDst.shape(2), computeLayoutDst.stride(3)); + + copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + AscendC::WaitFlag(eventIds[bufferIndex]); + + uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + loopOffset); + copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex], dstLayout, ubLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + + bufferIndex = (bufferIndex + 1) % BUFFER_NUM; + } + } else { + // Handle multiple tile each loop. 
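+ // Illustrative sizing (assumed figures, not taken from this patch): with half elements and COMPUTE_LENGTH = 96 KB / 2 B = 49152, a row rounded to 4096 elements gives tilesPerLoop = 12, i.e. each UB round trip stages up to 12 rows and then writes them out row by row into the block-padded destination.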
+ uint32_t tilesPerLoop = COMPUTE_LENGTH / roundTileLen; + coreLoops = (tilesPerAiv + tilesPerLoop - 1) / tilesPerLoop; + for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) { + uint32_t tileIdx = loopIdx * tilesPerLoop; + MatrixCoord tileOffset(tileIdx, 0); + uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + tileOffset); + uint32_t actualTilesNum = tilesPerLoop; + if (tilesPerAiv - tileIdx < tilesPerLoop) { + actualTilesNum = tilesPerAiv - tileIdx; + } + + AscendC::WaitFlag(eventIds[bufferIndex]); + ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(actualTilesNum, tileLen)); + ComputeLayout ubLayout = ComputeLayout{actualTilesNum, tileLen, roundTileLen}; + ComputeLayout dstLayout = ComputeLayout{CeilDiv(tileLen, computeLayoutDst.shape(2)), + computeLayoutDst.shape(2), computeLayoutDst.stride(3)}; + + copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + AscendC::WaitFlag(eventIds[bufferIndex]); + + for (uint32_t i = 0; i < actualTilesNum; ++i) { + uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + tileOffset + MatrixCoord(i, 0)); + uint64_t ubOffset = ubLayout.GetOffset(MatrixCoord(i, 0)); + copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex][ubOffset], dstLayout, ubLayout); + } + AscendC::SetFlag(eventIds[bufferIndex]); + + bufferIndex = (bufferIndex + 1) % BUFFER_NUM; + } + } + AscendC::WaitFlag(eventIds[0]); + AscendC::WaitFlag(eventIds[1]); + } + + CATLASS_DEVICE + ~PaddingMatrixBlockND() {} + +private: + static const uint32_t BUFFER_NUM = 2; + AscendC::LocalTensor inputBuffer[BUFFER_NUM]; + AscendC::TEventID eventIds[BUFFER_NUM] = {EVENT_ID0, EVENT_ID1}; + uint32_t bufferIndex{0}; + static_assert(BUFFER_NUM * COMPUTE_LENGTH * sizeof(Element) <= ArchTag::UB_SIZE, "Exceeding the UB space!"); + static_assert(std::is_same_v || std::is_same_v, + "Unsported layout for PaddingMatrixBlockNd!"); +}; + +template +struct PaddingMatrixND { +public: + using ArchTag = ArchTag_; + using Element = Element_; + using Layout = Layout_; + using CopyGm2Ub = Catlass::Epilogue::Tile::CopyGm2Ub>; + using CopyUb2Gm = Catlass::Epilogue::Tile::CopyUb2Gm>; + using ComputeLayout = Catlass::layout::RowMajor; + + static const PaddingTag paddingTag = PaddingTag::PADDING_ND; + using LayoutIn = Layout_; + using LayoutOut = Layout_; + CATLASS_HOST_DEVICE static LayoutOut GetWorkspaceLayout(LayoutIn &layout, uint32_t align) + { + if constexpr (std::is_same_v) { + return LayoutOut{layout.shape(0), layout.shape(1), RoundUp(layout.shape(1), align)}; + } else { + return LayoutOut{layout.shape(0), layout.shape(1), RoundUp(layout.shape(0), align)}; + } + } + static size_t GetWorkspaceSize(uint32_t rows, uint32_t cols, uint32_t align) + { + if constexpr (std::is_same_v) { + return static_cast(rows) * RoundUp(cols, align) * sizeof(Element); + } else { + return static_cast(cols) * RoundUp(rows, align) * sizeof(Element); + } + } + + CopyGm2Ub copyGm2Ub; + CopyUb2Gm copyUb2Gm; + + CATLASS_DEVICE + PaddingMatrixND(Arch::Resource &resource) + { + int64_t bufferOffset = 0; + for (uint32_t i = 0; i < BUFFER_NUM; i++) { + inputBuffer[i] = resource.ubBuf.template GetBufferByByte(bufferOffset * sizeof(Element)); + bufferOffset += COMPUTE_LENGTH; + } + } + + CATLASS_DEVICE + ComputeLayout GetPaddingComputeLayout(layout::RowMajor const &layout) + { + return ComputeLayout(layout.shape(0), layout.shape(1), layout.stride(0)); + } + + CATLASS_DEVICE + ComputeLayout GetPaddingComputeLayout(layout::ColumnMajor const &layout) 
+ { + return ComputeLayout(layout.shape(1), layout.shape(0), layout.stride(1)); + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::GlobalTensor const &src, + Layout layoutDst, Layout layoutSrc) + { + ComputeLayout computeLayoutSrc = GetPaddingComputeLayout(layoutSrc); + ComputeLayout computeLayoutDst = GetPaddingComputeLayout(layoutDst); + + uint32_t aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum(); + uint32_t aivId = AscendC::GetBlockIdx(); + + // Each line is a tile. + uint32_t tilesNum = computeLayoutSrc.shape(0); + uint32_t tileLen = computeLayoutSrc.shape(1); + uint32_t paddingStride = computeLayoutDst.stride(0); + + uint32_t tilesPerAiv = tilesNum / aivNum; + uint32_t tileRemain = tilesNum % aivNum; + if (aivId < tileRemain) { + tilesPerAiv++; + } + uint32_t mIdx = aivId * tilesPerAiv; + if (aivId >= tileRemain) { + mIdx += tileRemain; + } + MatrixCoord blockOffset(mIdx, 0); + + AscendC::SetFlag(eventIds[0]); + AscendC::SetFlag(eventIds[1]); + uint32_t coreLoops{0}; + if (paddingStride > COMPUTE_LENGTH) { + // Handle the same tile on multiple loops. + uint32_t loopsPerTile = (tileLen + COMPUTE_LENGTH - 1) / COMPUTE_LENGTH; + coreLoops = tilesPerAiv * loopsPerTile; + for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) { + uint32_t tileIdx = loopIdx / loopsPerTile; + uint32_t inTileLoopIdx = loopIdx % loopsPerTile; + MatrixCoord loopOffset(tileIdx, inTileLoopIdx * COMPUTE_LENGTH); + uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + loopOffset); + uint32_t actualDataNum = COMPUTE_LENGTH; + if (tileLen - inTileLoopIdx * COMPUTE_LENGTH < COMPUTE_LENGTH) { + actualDataNum = tileLen - inTileLoopIdx * COMPUTE_LENGTH; + } + + AscendC::WaitFlag(eventIds[bufferIndex]); + ComputeLayout dstLayout = computeLayoutDst.GetTileLayout(MatrixCoord(1, actualDataNum)); + ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(1, actualDataNum)); + ComputeLayout &ubLayout = dstLayout; + + copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + AscendC::WaitFlag(eventIds[bufferIndex]); + + uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + loopOffset); + copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex], dstLayout, ubLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + + bufferIndex = (bufferIndex + 1) % BUFFER_NUM; + } + } else { + // Handle multiple tile each loop. 
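+ // Note: only the leading dimension is padded on this path; the destination tile keeps the logical shape (actualTilesNum x tileLen), so the bytes between tileLen and the padded stride in each workspace row are simply left unwritten.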
+ uint32_t tilesPerLoop = COMPUTE_LENGTH / paddingStride; + coreLoops = (tilesPerAiv + tilesPerLoop - 1) / tilesPerLoop; + for (uint32_t loopIdx = 0; loopIdx < coreLoops; ++loopIdx) { + uint32_t tileIdx = loopIdx * tilesPerLoop; + MatrixCoord tileOffset(tileIdx, 0); + uint64_t gmSrcOffset = computeLayoutSrc.GetOffset(blockOffset + tileOffset); + uint32_t actualTilesNum = tilesPerLoop; + if (tilesPerAiv - tileIdx < tilesPerLoop) { + actualTilesNum = tilesPerAiv - tileIdx; + } + + AscendC::WaitFlag(eventIds[bufferIndex]); + ComputeLayout dstLayout = computeLayoutDst.GetTileLayout(MatrixCoord(actualTilesNum, tileLen)); + ComputeLayout srcLayout = computeLayoutSrc.GetTileLayout(MatrixCoord(actualTilesNum, tileLen)); + ComputeLayout &ubLayout = dstLayout; + + copyGm2Ub(inputBuffer[bufferIndex], src[gmSrcOffset], ubLayout, srcLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + AscendC::WaitFlag(eventIds[bufferIndex]); + + uint64_t gmDstOffset = computeLayoutDst.GetOffset(blockOffset + tileOffset); + copyUb2Gm(dst[gmDstOffset], inputBuffer[bufferIndex], dstLayout, ubLayout); + AscendC::SetFlag(eventIds[bufferIndex]); + + bufferIndex = (bufferIndex + 1) % BUFFER_NUM; + } + } + AscendC::WaitFlag(eventIds[0]); + AscendC::WaitFlag(eventIds[1]); + } + + CATLASS_DEVICE + ~PaddingMatrixND() {} + +private: + static const uint32_t BUFFER_NUM = 2; + AscendC::LocalTensor inputBuffer[BUFFER_NUM]; + AscendC::TEventID eventIds[BUFFER_NUM] = {EVENT_ID0, EVENT_ID1}; + uint32_t bufferIndex{0}; + static_assert(BUFFER_NUM * COMPUTE_LENGTH * sizeof(Element) <= ArchTag::UB_SIZE, "Exceeding the UB space!"); + static_assert(std::is_same_v || std::is_same_v, + "Unsported layout for PaddingMatrixND!"); +}; + +// The PaddingBuilder structure can construct the required padding class by specifying the PaddingTag +// and the basic information of the matrix, thereby unifying the use of various paddings. +// Moreover, it allows for quick retrieval of the layout information after padding. 
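+// Illustrative sketch of the intended use (template parameter lists are elided here and the alias names are placeholders, not part of this patch): +//   using BuilderA  = PaddingBuilder</* arch tag, element, input layout, PaddingTag, compute length */>; +//   using PrologueA = typename BuilderA::Padding;            // PaddingMatrixBlockND for PADDING_BLOCK_ND, void for NO_PADDING +//   using LayoutWA  = typename BuilderA::LayoutAfterPadding; // layout::PaddingRowMajor for a RowMajor input with PADDING_BLOCK_ND +// Kernels such as OptimizedMatmul select their prologue and KernelParams variant based on whether Padding is void.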
+template +struct PaddingBuilder { + static_assert(DEPENDENT_FALSE, "Padding is not implemented for this layout"); +}; + +template +struct PaddingBuilder { + using LayoutAfterPadding = LayoutIn; + using Padding = void; +}; + +template +struct PaddingBuilder { + using LayoutAfterPadding = LayoutIn; + using Padding = Catlass::Gemm::Kernel::PaddingMatrixND; +}; + +template +struct PaddingBuilder { + using LayoutAfterPadding = std::conditional_t, layout::PaddingRowMajor, + layout::PaddingColumnMajor>; + using Padding = + Catlass::Gemm::Kernel::PaddingMatrixBlockND; +}; + +template +class PaddingMatmul +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using ElementA = typename BlockMmad::ElementA; + using ElementB = typename BlockMmad::ElementB; + using LayoutA = typename BlockMmad::LayoutA; + using LayoutB = typename BlockMmad::LayoutB; + + static const uint32_t COMPUTE_LENGTH_A = 96 * 1024 / sizeof(ElementA); + using PaddingA = PaddingMatrixND; + static const uint32_t COMPUTE_LENGTH_B = 96 * 1024 / sizeof(ElementB); + using PaddingB = PaddingMatrixND; + + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + + using BlockScheduler = BlockScheduler_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrC; + LayoutC layoutC; + GM_ADDR ptrWA; + LayoutA layoutWA; + GM_ADDR ptrWB; + LayoutB layoutWB; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_, GM_ADDR ptrWA_, LayoutA layoutWA_, GM_ADDR ptrWB_, LayoutB layoutWB_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrC(ptrC_), + layoutC(layoutC_), + ptrWA(ptrWA_), + layoutWA(layoutWA_), + ptrWB(ptrWB_), + layoutWB(layoutWB_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + uint32_t align; + size_t elementSize; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR ptrC; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static layout::RowMajor GetWorkspaceLayout(layout::RowMajor layout, uint32_t align) + { + // prevent division of 0 + if (align == 0) { + return 0; + } + return layout::RowMajor(layout.shape(0), layout.shape(1), (layout.shape(1) + align - 1) / align * align); + } + + static layout::ColumnMajor GetWorkspaceLayout(layout::ColumnMajor layout, uint32_t align) + { + return layout::ColumnMajor(layout.shape(0), layout.shape(1), (layout.shape(0) + align - 1) / align * align); + } + + static size_t GetWorkspaceLen(layout::RowMajor layout) + { + return layout.shape(0) * layout.stride(0); + } + + static size_t GetWorkspaceLen(layout::ColumnMajor layout) + { + return layout.shape(1) * layout.stride(1); + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + GemmCoord problemShape = args.problemShape; + LayoutA layoutA{problemShape.m(), problemShape.k()}; + LayoutB layoutB{problemShape.k(), problemShape.n()}; + size_t sizeWA = GetWorkspaceLen(GetWorkspaceLayout(layoutA, args.align)) * args.elementSize; + size_t sizeWB = GetWorkspaceLen(GetWorkspaceLayout(layoutB, args.align)) * args.elementSize; + return sizeWA + sizeWB; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + GemmCoord 
problemShape = args.problemShape; + uint32_t m = problemShape.m(); + uint32_t n = problemShape.n(); + uint32_t k = problemShape.k(); + LayoutA layoutA{m, k}; + LayoutB layoutB{k, n}; + LayoutC layoutC{m, n}; + size_t sizeWA = GetWorkspaceLen(GetWorkspaceLayout(layoutA, args.align)) * args.elementSize; + uint8_t *workspaceWB = workspace + sizeWA; + Params params{problemShape, + args.ptrA, + layoutA, + args.ptrB, + layoutB, + args.ptrC, + layoutC, + workspace, + GetWorkspaceLayout(layoutA, args.align), + workspaceWB, + GetWorkspaceLayout(layoutB, args.align)}; + return params; + } + + // Methods + CATLASS_DEVICE + PaddingMatmul() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + AscendC::GlobalTensor gmA; + AscendC::GlobalTensor gmWA; + gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrA)); + gmWA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrWA)); + PaddingA paddingA(resource); + paddingA(gmWA, gmA, params.layoutWA, params.layoutA); + + AscendC::GlobalTensor gmB; + AscendC::GlobalTensor gmWB; + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(params.ptrB)); + gmWB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(params.ptrWB)); + PaddingB paddingB(resource); + paddingB(gmWB, gmB, params.layoutWB, params.layoutB); + // 0x0 synchronization control between AI Core + Catlass::Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishPadding); + + AscendC::PipeBarrier(); + } + + /// Executes matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + Catlass::Arch::CrossCoreWaitFlag(flagAivFinishPadding); + + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrWA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrWB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); + + BlockMmad blockMmad(resource); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute block location + GemmCoord blockIdxCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockIdxCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockIdxCoord.m() * L1TileShape::M, blockIdxCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockIdxCoord.k() * L1TileShape::K, blockIdxCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockIdxCoord.m() * L1TileShape::M, blockIdxCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = params.layoutWA.GetOffset(offsetA); + int64_t gmOffsetB = params.layoutWB.GetOffset(offsetB); + int64_t gmOffsetC = params.layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[gmOffsetA], params.layoutWA, gmB[gmOffsetB], params.layoutWB, gmC[gmOffsetC], params.layoutC, + actualBlockShape); + } + + AscendC::PipeBarrier(); + } + +private: + static constexpr Arch::FlagID FLAG_AIV_FINISH_STORE = 0; + Arch::CrossCoreFlag flagAivFinishPadding{FLAG_AIV_FINISH_STORE}; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_PADDING_MATMUL_HPP diff --git 
a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/padding_splitk_matmul.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/padding_splitk_matmul.hpp new file mode 100644 index 00000000..f89fc601 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/padding_splitk_matmul.hpp @@ -0,0 +1,361 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_PADDING_SPLITK_MATMUL_HPP +#define CATLASS_GEMM_KERNEL_PADDING_SPLITK_MATMUL_HPP + +#include +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/epilogue/tile/copy_gm_to_ub.hpp" +#include "catlass/epilogue/tile/copy_ub_to_gm.hpp" +#include "catlass/gemm/kernel/padding_matmul.hpp" +#include "catlass/gemm/kernel/splitk_matmul.hpp" + +namespace Catlass::Gemm::Kernel { + +// Template for Matmul kernel. Compute C = A * B +template +class PaddingSplitkMatmul +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + static const uint32_t COMPUTE_LENGTH_A = 96 * 1024 / sizeof(ElementA); + using PaddingA = PaddingMatrixND; + static const uint32_t COMPUTE_LENGTH_B = 96 * 1024 / sizeof(ElementB); + using PaddingB = PaddingMatrixND; + + using BlockScheduler = BlockScheduler_; + using ReduceAdd = ReduceAdd_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + bool aNeedPadding; + bool bNeedPadding; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrC; + LayoutC layoutC; + GM_ADDR ptrWA; + LayoutA layoutWA; + GM_ADDR ptrWB; + LayoutB layoutWB; + GM_ADDR ptrWC; + uint32_t splitkFactor = 1; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, bool aNeedPadding_, bool bNeedPadding_, GM_ADDR ptrA_, LayoutA layoutA_, + GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrC_, LayoutC layoutC_, GM_ADDR ptrWA_, LayoutA layoutWA_, + GM_ADDR ptrWB_, LayoutB layoutWB_, GM_ADDR ptrWC_, uint32_t splitkFactor_) + : problemShape(problemShape_), + aNeedPadding(aNeedPadding_), + bNeedPadding(bNeedPadding_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrC(ptrC_), + layoutC(layoutC_), + ptrWA(ptrWA_), + layoutWA(layoutWA_), + ptrWB(ptrWB_), + layoutWB(layoutWB_), + ptrWC(ptrWC_), + splitkFactor(splitkFactor_) + {} + }; + + struct Arguments { + 
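// Host-side launch description: align and elementSize drive the padded workspace layouts, aNeedPadding / bNeedPadding select whether A / B are first copied into those workspaces, and aicCoreNum feeds the split-k heuristic below. +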
GemmCoord problemShape; + uint32_t aicCoreNum; + uint32_t align; + bool aNeedPadding; + bool bNeedPadding; + size_t elementSize; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR ptrC; + }; + + static uint32_t GetSplitkFactor(uint32_t m, uint32_t n, uint32_t k, uint32_t aicCoreNum) + { + uint32_t maxSplitkFactor; + if (k <= 1024) { + // When k is less than or equal to 1024, it can be divided into at most 2 parts. + maxSplitkFactor = 2; + } else if (k <= 2048) { + // When k is less than or equal to 2048, it can be divided into at most 4 parts. + maxSplitkFactor = 4; + } else if (k <= 4096) { + // When k is less than or equal to 4096, it can be divided into at most 8 parts. + maxSplitkFactor = 8; + } else { + // Otherwise it can be divided into at most 16 parts. + maxSplitkFactor = 16; + } + uint32_t splitkFactor = 1; + uint32_t m0 = L1TileShape::M; + uint32_t n0 = L1TileShape::N; + uint32_t k0 = L1TileShape::K; + + uint32_t baseTilesCount = CeilDiv(m, m0) * CeilDiv(n, n0); + splitkFactor = std::min(aicCoreNum / baseTilesCount, maxSplitkFactor); + // Prevent the split factor from being less than 1 + splitkFactor = std::max(splitkFactor, static_cast(1)); + if (baseTilesCount < aicCoreNum) { + while (splitkFactor + 1 <= maxSplitkFactor && CeilDiv(baseTilesCount * splitkFactor, aicCoreNum) >= + CeilDiv(baseTilesCount, aicCoreNum) * splitkFactor) { + splitkFactor += 1; + } + } + // Ensure that splitkFactor does not exceed the number of base tiles in the k direction. + splitkFactor = std::min(CeilDiv(k, k0), splitkFactor); + // If k is very large, splitting k can lead to better cache utilization. + // If k is greater than 8192. + if (k > 8192) { + // split the k direction into at least 2 parts. + splitkFactor = std::max(splitkFactor, static_cast(2)); + } + // If k is greater than 32768. + if (k > 32768) { + // split the k direction into at least 4 parts.
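+ // End-to-end example of this heuristic with assumed sizes (m = n = 1024, k = 16384, 128x256x256 L1 tiles, 24 AI Cores): baseTilesCount = 8 * 4 = 32, 24 / 32 = 0 is clamped to 1, min(CeilDiv(16384, 256), 1) = 1, and the k > 8192 rule then raises the factor to 2, so C is accumulated in two workspace slices and reduced afterwards.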
+ splitkFactor = std::max(splitkFactor, static_cast(4)); + } + return splitkFactor; + } + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static layout::RowMajor GetWorkspaceLayout(layout::RowMajor layout, uint32_t align) + { + // prevent division of 0 + if (align == 0) { + return layout; + } + return layout::RowMajor(layout.shape(0), layout.shape(1), (layout.shape(1) + align - 1) / align * align); + } + + static layout::ColumnMajor GetWorkspaceLayout(layout::ColumnMajor layout, uint32_t align) + { + // prevent division of 0 + if (align == 0) { + return layout; + } + return layout::ColumnMajor(layout.shape(0), layout.shape(1), (layout.shape(0) + align - 1) / align * align); + } + + static size_t GetWorkspaceLen(layout::RowMajor layout) + { + return layout.shape(0) * layout.stride(0); + } + + static size_t GetWorkspaceLen(layout::ColumnMajor layout) + { + return layout.shape(1) * layout.stride(1); + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + GemmCoord problemShape = args.problemShape; + LayoutA layoutA{problemShape.m(), problemShape.k()}; + LayoutB layoutB{problemShape.k(), problemShape.n()}; + size_t sizeWA = GetWorkspaceLen(GetWorkspaceLayout(layoutA, args.align)) * args.elementSize; + size_t sizeWB = GetWorkspaceLen(GetWorkspaceLayout(layoutB, args.align)) * args.elementSize; + size_t sizeWC = + args.elementSize * args.problemShape.m() * args.problemShape.n() * + GetSplitkFactor(args.problemShape.m(), args.problemShape.n(), args.problemShape.k(), args.aicCoreNum); + return sizeWA + sizeWB + sizeWC; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + LayoutA layoutA{args.problemShape.m(), args.problemShape.k()}; + LayoutB layoutB{args.problemShape.k(), args.problemShape.n()}; + LayoutC layoutC{args.problemShape.m(), args.problemShape.n()}; + + uint8_t *workspaceWA = nullptr; + uint8_t *workspaceWB = nullptr; + size_t sizeWA = 0; + size_t sizeWB = 0; + + if (args.aNeedPadding) { + workspaceWA = workspace; + sizeWA = GetWorkspaceLen(GetWorkspaceLayout(layoutA, args.align)) * args.elementSize; + } else { + workspaceWA = args.ptrA; + } + + if (args.bNeedPadding) { + workspaceWB = workspace + sizeWA; + sizeWB = GetWorkspaceLen(GetWorkspaceLayout(layoutB, args.align)) * args.elementSize; + } else { + workspaceWB = args.ptrB; + } + + uint8_t *workspaceWC = workspace + sizeWA + sizeWB; + + Params params{ + args.problemShape, + args.aNeedPadding, + args.bNeedPadding, + args.ptrA, + layoutA, + args.ptrB, + layoutB, + args.ptrC, + layoutC, + workspaceWA, + GetWorkspaceLayout(layoutA, args.align), + workspaceWB, + GetWorkspaceLayout(layoutB, args.align), + workspaceWC, + GetSplitkFactor(args.problemShape.m(), args.problemShape.n(), args.problemShape.k(), args.aicCoreNum)}; + return params; + } + + // Methods + CATLASS_DEVICE + PaddingSplitkMatmul() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + /// Executes one Matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + if (params.aNeedPadding || params.bNeedPadding) { + Catlass::Arch::CrossCoreWaitFlag(flagAivFinishPadding); + } + + BlockScheduler matmulBlockScheduler( + params.problemShape, GemmCoord(L1TileShape::M, L1TileShape::N, L1TileShape::K), params.splitkFactor); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + Arch::Resource resource; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrWA); + 
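// A and B are read from ptrWA / ptrWB, which alias the original tensors when no padding was requested; each split-k slice accumulates its partial C into its own m x n region of ptrWC (offset by the slice index), and the AIV side reduces those slices into ptrC afterwards. +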
AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrWB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrWC); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = + matmulBlockScheduler.GetActualBlockShape(blockCoord, matmulBlockScheduler.GetSplitkSliceIdx(loopIdx)); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + uint64_t gmOffsetA = params.layoutWA.GetOffset(offsetA); + uint64_t gmOffsetB = params.layoutWB.GetOffset(offsetB); + uint64_t gmOffsetC = params.layoutC.GetOffset(offsetC) + + static_cast(params.problemShape.m()) * + static_cast(params.problemShape.n()) * + static_cast(matmulBlockScheduler.GetSplitkSliceIdx(loopIdx)); + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[gmOffsetA], params.layoutWA, gmB[gmOffsetB], params.layoutWB, gmC[gmOffsetC], params.layoutC, + actualBlockShape); + } + + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(flagAicFinish); + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + if (params.aNeedPadding) { + AscendC::GlobalTensor gmA; + AscendC::GlobalTensor gmWA; + gmA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrA)); + gmWA.SetGlobalBuffer(reinterpret_cast<__gm__ ElementA *>(params.ptrWA)); + PaddingA paddingA(resource); + paddingA(gmWA, gmA, params.layoutWA, params.layoutA); + } + + if (params.bNeedPadding) { + AscendC::GlobalTensor gmB; + AscendC::GlobalTensor gmWB; + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(params.ptrB)); + gmWB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(params.ptrWB)); + PaddingB paddingB(resource); + paddingB(gmWB, gmB, params.layoutWB, params.layoutB); + } + + if (params.aNeedPadding || params.bNeedPadding) { + // 0x0 synchronization control between AI Core + Catlass::Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishPadding); + } + + // reduce add + using ElementOut = typename ReduceAdd::ElementOut; + using ElementAccumulator = typename ReduceAdd::ElementAccumulator; + + Catlass::Arch::CrossCoreWaitFlag(flagAicFinish); + Catlass::Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + + AscendC::GlobalTensor gmC; + AscendC::GlobalTensor gmWC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementOut *>(params.ptrC)); + gmWC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementAccumulator *>(params.ptrWC)); + ReduceAdd reduceAdd(resource); + reduceAdd(gmC, gmWC, + static_cast(params.problemShape.m()) * static_cast(params.problemShape.n()), + params.splitkFactor); + + AscendC::PipeBarrier(); + } + +private: + static constexpr Arch::FlagID FLAG_AIC_FINISH = 0; + Arch::CrossCoreFlag flagAicFinish{FLAG_AIC_FINISH}; + static constexpr Arch::FlagID FLAG_AIV_FINISH_STORE = 1; + Arch::CrossCoreFlag flagAivFinishPadding{FLAG_AIV_FINISH_STORE}; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_PADDING_SPLITK_MATMUL_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/quant_matmul.hpp 
b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/quant_matmul.hpp new file mode 100644 index 00000000..67cee5d0 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/quant_matmul.hpp @@ -0,0 +1,231 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_QUANT_MATMUL_HPP +#define CATLASS_GEMM_KERNEL_QUANT_MATMUL_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +template +class QuantMatmul +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using BlockScheduler = BlockScheduler_; + + friend class AicFinishSync; + friend class AivWaitSync; + + struct AicFinishSync { + using MatmulKernel = QuantMatmul; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlagWithReverse<0x2, PIPE_FIX>(ptr->flagAicFinishStore); + } + + MatmulKernel *ptr; + }; + + struct AivWaitSync { + using MatmulKernel = QuantMatmul; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreWaitFlagWithReverse<0x2, PIPE_MTE3>(ptr->flagAicFinishStore); + } + + MatmulKernel *ptr; + }; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementD *ptrD; + LayoutD layoutD; + GM_ADDR ptrWorkspace; + + // Methods + CATLASS_DEVICE + Params() {} + + CATLASS_DEVICE + Params(GemmCoord problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, 
LayoutD layoutD_, GM_ADDR ptrWorkspace_) + : problemShape(problemShape_), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)), + layoutD(layoutD_), + ptrWorkspace(ptrWorkspace_) + {} + }; + + // Methods + CATLASS_DEVICE + QuantMatmul() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + /// Executes matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + blockScheduler.Update(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + layout::RowMajor layoutC(params.problemShape.m(), params.problemShape.n()); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + + AicFinishSync aicFinishSync{this}; + + for (uint32_t loopIdx = coreIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + int64_t gmOffsetB = params.layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC], layoutC, + actualBlockShape, MakeCallback(&aicFinishSync)); + } else { + blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC], layoutC, + actualBlockShape); + aicFinishSync(); + } + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource); + + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + uint32_t subCoreIndex = AscendC::GetSubBlockIdx(); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + + AivWaitSync aicFinishSync{this}; + + LayoutC layoutC = LayoutC(params.problemShape.m(), params.problemShape.n()); + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(params.problemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = params.layoutD.GetTileLayout(params.problemShape.GetCoordMN()); + + 
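+ // Bundle the scale, per-token scale and output tensor descriptors for the block epilogue.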
EpilogueParams epilogueParams{params.ptrScale, layoutScale, params.ptrPerTokenScale, + layoutPerTokenScale, params.ptrD, layoutD}; + + blockScheduler.Update(params.problemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + for (uint32_t loopIdx = coreIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + auto gmBlockC = gmC[layoutC.GetOffset(blockCoordMNK.GetCoordMN() * blockShapeMNK.GetCoordMN())]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC, + MakeCallback(&aicFinishSync)); + } + + AscendC::PipeBarrier(); + } + +private: + static constexpr Arch::FlagID FLAG_AIC_FINISH_STORE = 0; + static constexpr Arch::FlagID RV_FLAG_AIC_FINISH_STORE = 1; + Arch::CrossCoreFlagWithReverse<> flagAicFinishStore{FLAG_AIC_FINISH_STORE, RV_FLAG_AIC_FINISH_STORE}; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_QUANT_MATMUL_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/quant_matmul_multistage_workspace.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/quant_matmul_multistage_workspace.hpp new file mode 100644 index 00000000..54d4b438 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/quant_matmul_multistage_workspace.hpp @@ -0,0 +1,313 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_KERNEL_QUANT_MATMUL_MULTISTAGE_WORKSPACE_HPP +#define CATLASS_GEMM_KERNEL_QUANT_MATMUL_MULTISTAGE_WORKSPACE_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +template +class QuantMatmulMultiStageWorkspace +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementD *ptrD; + LayoutD layoutD; + GM_ADDR ptrWorkspace; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, LayoutD layoutD_, GM_ADDR ptrWorkspace_) + : problemShape(problemShape_), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)), + layoutD(layoutD_), + ptrWorkspace(ptrWorkspace_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + uint32_t aicCoreNum; + uint8_t *ptrA; + uint8_t *ptrB; + uint8_t *ptrScale; + uint8_t *ptrPerTokenScale; + uint8_t *ptrD; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + size_t lenWorkspace = static_cast(L1TileShape::M) * L1TileShape::N * args.aicCoreNum * WORKSPACE_STAGES; + size_t sizeWorkspace = lenWorkspace * sizeof(uint32_t); + return sizeWorkspace; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + uint32_t m = args.problemShape.m(); + uint32_t n = args.problemShape.n(); + 
uint32_t k = args.problemShape.k(); + LayoutA layoutA{m, k}; + LayoutB layoutB{k, n}; + LayoutScale layoutScale{n}; + LayoutPerTokenScale layoutPerTokenScale{m}; + LayoutD layoutD{m, n}; + Params params{ + args.problemShape, args.ptrA, layoutA, args.ptrB, layoutB, args.ptrScale, layoutScale, + args.ptrPerTokenScale, layoutPerTokenScale, args.ptrD, layoutD, workspace}; + return params; + } + + // Methods + CATLASS_DEVICE + QuantMatmulMultiStageWorkspace() + { + Arch::FlagID flagId = 0; + for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { + flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++); + flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); + aicWaitFuncList[stageId] = {this, stageId}; + aicSetFuncList[stageId] = {this, stageId}; + } + } + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + blockScheduler.Update(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = coreIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + int64_t gmOffsetB = params.layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC], layoutC, + actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC], layoutC, + actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? 
(stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); + --stageUsed; + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource); + + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(params.problemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = params.layoutD.GetTileLayout(params.problemShape.GetCoordMN()); + + EpilogueParams epilogueParams{params.ptrScale, layoutScale, params.ptrPerTokenScale, + layoutPerTokenScale, params.ptrD, layoutD}; + + blockScheduler.Update(params.problemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + for (uint32_t loopIdx = coreIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); + blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]); + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? 
(stageId + 1) : 0; + } + + AscendC::PipeBarrier(); + } + +private: + friend struct AicWaitFunc; + friend struct AicSetFunc; + + struct AicWaitFunc { + using MatmulKernel = QuantMatmulMultiStageWorkspace; + + CATLASS_DEVICE + AicWaitFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + struct AicSetFunc { + using MatmulKernel = QuantMatmulMultiStageWorkspace; + + CATLASS_DEVICE + AicSetFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; + Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; + + AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES]; + AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_QUANT_MATMUL_MULTISTAGE_WORKSPACE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/splitk_matmul.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/splitk_matmul.hpp new file mode 100644 index 00000000..61cf8a2a --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/splitk_matmul.hpp @@ -0,0 +1,373 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_KERNEL_SPLITK_MATMUL_HPP +#define CATLASS_GEMM_KERNEL_SPLITK_MATMUL_HPP + +#include +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +template +struct ReduceAdd { + using ArchTag = ArchTag_; + using ElementAccumulator = ElementAccumulator_; + using ElementOut = ElementOut_; + + CATLASS_DEVICE + ReduceAdd(Arch::Resource &resource) + { + int64_t bufferOffset = 0; + for (uint32_t i = 0; i < BUFFER_NUM; i++) { + inputBuffer[i] = resource.ubBuf.template GetBufferByByte(bufferOffset); + bufferOffset += COMPUTE_LENGTH * sizeof(ElementAccumulator); + accumulatorBuffer[i] = resource.ubBuf.template GetBufferByByte(bufferOffset); + bufferOffset += COMPUTE_LENGTH * sizeof(ElementAccumulator); + outputBuffer[i] = resource.ubBuf.template GetBufferByByte(bufferOffset); + bufferOffset += COMPUTE_LENGTH * sizeof(ElementOut); + } + } + + CATLASS_DEVICE + void Gm2Ub(AscendC::LocalTensor const &dst, + AscendC::GlobalTensor const &src, uint32_t dataNum) + { + AscendC::DataCopyExtParams dataCopyParams(1, dataNum * sizeof(ElementAccumulator), 0, 0, 0); + AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); + AscendC::DataCopyPad(dst, src, dataCopyParams, padParams); + } + + CATLASS_DEVICE + void Ub2Gm(AscendC::GlobalTensor const &dst, AscendC::LocalTensor const &src, + uint32_t dataNum) + { + AscendC::DataCopyExtParams dataCopyParams(1, dataNum * sizeof(ElementOut), 0, 0, 0); + AscendC::DataCopyPad(dst, src, dataCopyParams); + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::GlobalTensor const &src, + uint64_t elementCount, uint32_t splitkFactor) + { + // The vec mte processes 256 bytes of data at a time. 
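+ // Each AIV's share of the elements is therefore rounded up to a whole number of 256-byte vector blocks.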
+ constexpr uint32_t ELE_PER_VECOTR_BLOCK = 256 / sizeof(ElementAccumulator); + uint32_t aivNum = AscendC::GetBlockNum() * AscendC::GetSubBlockNum(); + uint32_t aivId = AscendC::GetBlockIdx(); + uint64_t taskPerAiv = + (elementCount / aivNum + ELE_PER_VECOTR_BLOCK - 1) / ELE_PER_VECOTR_BLOCK * ELE_PER_VECOTR_BLOCK; + if (taskPerAiv == 0) taskPerAiv = ELE_PER_VECOTR_BLOCK; + uint32_t tileLen; + if (taskPerAiv > COMPUTE_LENGTH) { + tileLen = COMPUTE_LENGTH; + } else { + tileLen = taskPerAiv; + } + + AscendC::SetFlag(inputEventIds[0]); + AscendC::SetFlag(inputEventIds[1]); + AscendC::SetFlag(outputEventIds[0]); + AscendC::SetFlag(outputEventIds[1]); + AscendC::SetFlag(accumulatorEventIds[0]); + AscendC::SetFlag(accumulatorEventIds[1]); + + uint32_t loops = (elementCount + tileLen - 1) / tileLen; + for (uint32_t loopIdx = aivId; loopIdx < loops; loopIdx += aivNum) { + uint32_t actualTileLen = tileLen; + if (loopIdx == loops - 1) { + actualTileLen = elementCount - loopIdx * tileLen; + } + + AscendC::WaitFlag(accumulatorEventIds[bufferIndex]); + Gm2Ub(accumulatorBuffer[bufferIndex], src[loopIdx * tileLen], actualTileLen); + AscendC::SetFlag(accumulatorEventIds[bufferIndex]); + AscendC::WaitFlag(accumulatorEventIds[bufferIndex]); + + for (uint32_t sliceIdx = 1; sliceIdx < splitkFactor; ++sliceIdx) { + AscendC::WaitFlag(inputEventIds[bufferIndex]); + Gm2Ub(inputBuffer[bufferIndex], src[sliceIdx * elementCount + loopIdx * tileLen], actualTileLen); + AscendC::SetFlag(inputEventIds[bufferIndex]); + AscendC::WaitFlag(inputEventIds[bufferIndex]); + + AscendC::Add(accumulatorBuffer[bufferIndex], accumulatorBuffer[bufferIndex], inputBuffer[bufferIndex], + actualTileLen); + AscendC::SetFlag(inputEventIds[bufferIndex]); + } + AscendC::PipeBarrier(); + + AscendC::WaitFlag(outputEventIds[bufferIndex]); + if constexpr (!std::is_same_v) { + if constexpr (std::is_same_v) { + AscendC::Cast(outputBuffer[bufferIndex], accumulatorBuffer[bufferIndex], + AscendC::RoundMode::CAST_NONE, actualTileLen); + } else { + AscendC::Cast(outputBuffer[bufferIndex], accumulatorBuffer[bufferIndex], + AscendC::RoundMode::CAST_RINT, actualTileLen); + } + } else { + AscendC::DataCopy(outputBuffer[bufferIndex], accumulatorBuffer[bufferIndex], tileLen); + } + AscendC::SetFlag(accumulatorEventIds[bufferIndex]); + + AscendC::SetFlag(outputEventIds[bufferIndex]); + AscendC::WaitFlag(outputEventIds[bufferIndex]); + Ub2Gm(dst[loopIdx * tileLen], outputBuffer[bufferIndex], actualTileLen); + AscendC::SetFlag(outputEventIds[bufferIndex]); + + bufferIndex = (bufferIndex + 1) % BUFFER_NUM; + } + + AscendC::WaitFlag(inputEventIds[0]); + AscendC::WaitFlag(inputEventIds[1]); + AscendC::WaitFlag(outputEventIds[0]); + AscendC::WaitFlag(outputEventIds[1]); + AscendC::WaitFlag(accumulatorEventIds[0]); + AscendC::WaitFlag(accumulatorEventIds[1]); + } + +private: + static const uint32_t BUFFER_NUM = 2; + AscendC::LocalTensor inputBuffer[BUFFER_NUM]; + AscendC::LocalTensor accumulatorBuffer[BUFFER_NUM]; + AscendC::LocalTensor outputBuffer[BUFFER_NUM]; + AscendC::TEventID inputEventIds[BUFFER_NUM] = {EVENT_ID0, EVENT_ID1}; + AscendC::TEventID accumulatorEventIds[BUFFER_NUM] = {EVENT_ID2, EVENT_ID3}; + AscendC::TEventID outputEventIds[BUFFER_NUM] = {EVENT_ID0, EVENT_ID1}; + uint32_t bufferIndex{0}; + static_assert(BUFFER_NUM * COMPUTE_LENGTH * sizeof(ElementAccumulator) * 2 + + BUFFER_NUM * COMPUTE_LENGTH * sizeof(ElementOut) <= + ArchTag::UB_SIZE, + "Exceeding the UB space!"); +}; + +// Template for Matmul kernel. 
Compute C = A * B +template +class SplitkMatmul +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockScheduler = BlockScheduler_; + using ReduceAdd = ReduceAdd_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrC; + LayoutC layoutC; + GM_ADDR ptrWorkspace; + uint32_t splitkFactor = 1; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_, GM_ADDR ptrWorkspace_, uint32_t splitkFactor_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrC(ptrC_), + layoutC(layoutC_), + ptrWorkspace(ptrWorkspace_), + splitkFactor(splitkFactor_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + uint32_t aicCoreNum; + size_t workspaceElementSize; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR ptrC; + }; + + static uint32_t GetSplitkFactor(uint32_t m, uint32_t n, uint32_t k, uint32_t aicCoreNum) + { + uint32_t maxSplitkFactor; + if (k <= 1024) { + // When k is less than or equal to 1024, it can be divided into at most 2 parts. + maxSplitkFactor = 2; + } else if (k <= 2048) { + // When k is less than or equal to 2048, it can be divided into at most 4 parts. + maxSplitkFactor = 4; + } else if (k <= 4096) { + // When k is less than or equal to 4096, it can be divided into at most 8 parts. + maxSplitkFactor = 8; + } else { + // else it can be divided into at most 16 parts. + maxSplitkFactor = 16; + } + uint32_t splitkFactor = 1; + uint32_t m0 = L1TileShape::M; + uint32_t n0 = L1TileShape::N; + uint32_t k0 = L1TileShape::K; + + uint32_t baseTilesCount = CeilDiv(m, m0) * CeilDiv(n, n0); + splitkFactor = std::min(aicCoreNum / baseTilesCount, maxSplitkFactor); + // Prevent the split factor form being less than 1 + splitkFactor = std::max(splitkFactor, static_cast(1)); + if (baseTilesCount < aicCoreNum) { + while (splitkFactor + 1 <= maxSplitkFactor && CeilDiv(baseTilesCount * splitkFactor, aicCoreNum) >= + CeilDiv(baseTilesCount, aicCoreNum) * splitkFactor) { + splitkFactor += 1; + } + } + // Ensure that splitkFactor is less than the number of base tiels in the k direction. + splitkFactor = std::min(CeilDiv(k, k0), splitkFactor); + // If k is very large, splitting k can lead to better cache utilization. + // If k is greater than 8192. + if (k > 8192) { + // split the k direction into at least 2 parts. + splitkFactor = std::max(splitkFactor, static_cast(2)); + } + // If k is greater than 32768. + if (k > 32768) { + // split the k direction into at least 4 parts. 
+ splitkFactor = std::max(splitkFactor, static_cast(4)); + } + return splitkFactor; + } + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return args.workspaceElementSize * args.problemShape.m() * args.problemShape.n() * + GetSplitkFactor(args.problemShape.m(), args.problemShape.n(), args.problemShape.k(), args.aicCoreNum); + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + LayoutA layoutA{args.problemShape.m(), args.problemShape.k()}; + LayoutB layoutB{args.problemShape.k(), args.problemShape.n()}; + LayoutC layoutC{args.problemShape.m(), args.problemShape.n()}; + Params params{ + args.problemShape, + args.ptrA, + layoutA, + args.ptrB, + layoutB, + args.ptrC, + layoutC, + workspace, + GetSplitkFactor(args.problemShape.m(), args.problemShape.n(), args.problemShape.k(), args.aicCoreNum)}; + return params; + } + + // Methods + CATLASS_DEVICE + SplitkMatmul() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + /// Executes one Matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler matmulBlockScheduler( + params.problemShape, GemmCoord(L1TileShape::M, L1TileShape::N, L1TileShape::K), params.splitkFactor); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + Arch::Resource resource; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrWorkspace); + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute block location + GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = + matmulBlockScheduler.GetActualBlockShape(blockCoord, matmulBlockScheduler.GetSplitkSliceIdx(loopIdx)); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; + uint64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + uint64_t gmOffsetB = params.layoutB.GetOffset(offsetB); + uint64_t gmOffsetC = params.layoutC.GetOffset(offsetC) + + static_cast(params.problemShape.m()) * + static_cast(params.problemShape.n()) * + static_cast(matmulBlockScheduler.GetSplitkSliceIdx(loopIdx)); + + // Compute block-scoped matrix multiply-add + blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC], params.layoutC, + actualBlockShape); + } + + Catlass::Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(flagAicFinish); + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + using ElementOut = typename ReduceAdd::ElementOut; + using ElementAccumulator = typename ReduceAdd::ElementAccumulator; + + Catlass::Arch::CrossCoreWaitFlag(flagAicFinish); + Catlass::Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + + AscendC::GlobalTensor gmC; + AscendC::GlobalTensor gmWorkspace; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementOut *>(params.ptrC)); + gmWorkspace.SetGlobalBuffer(reinterpret_cast<__gm__ ElementAccumulator *>(params.ptrWorkspace)); + ReduceAdd 
reduceAdd(resource); + reduceAdd(gmC, gmWorkspace, + static_cast(params.problemShape.m()) * static_cast(params.problemShape.n()), + params.splitkFactor); + + AscendC::PipeBarrier(); + } + +private: + static constexpr Arch::FlagID FLAG_AIC_FINISH = 0; + Arch::CrossCoreFlag flagAicFinish{FLAG_AIC_FINISH}; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_SPLITK_MATMUL_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/w8a16_matmul.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/w8a16_matmul.hpp new file mode 100644 index 00000000..eafccdf1 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/kernel/w8a16_matmul.hpp @@ -0,0 +1,237 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_KERNEL_W8A16_MATMUL_HPP +#define CATLASS_GEMM_KERNEL_W8A16_MATMUL_HPP + +#include "catlass/catlass.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/epilogue/tile/copy_gm_to_ub.hpp" +#include "catlass/epilogue/tile/copy_ub_to_gm.hpp" +#include "catlass/gemm/kernel/padding_matmul.hpp" + +namespace Catlass::Gemm::Kernel { + +template +class W8A16Matmul +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using ElementA = typename BlockMmad::ElementA; + using ElementB = typename BlockMmad::ElementB; + using LayoutA = typename BlockMmad::LayoutA; + using LayoutB = typename BlockMmad::LayoutB; + + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + + using BlockScheduler = BlockScheduler_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrB; + LayoutB layoutB; + GM_ADDR ptrC; + LayoutC layoutC; + GM_ADDR ptrWksp; + half deqScalar; + half deqZeroPoint; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, + GM_ADDR ptrC_, LayoutC layoutC_, GM_ADDR ptrWksp_, half deqScalar_, half deqZeroPoint_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrB(ptrB_), + layoutB(layoutB_), + ptrC(ptrC_), + layoutC(layoutC_), + ptrWksp(ptrWksp_), + deqScalar(deqScalar_), + deqZeroPoint(deqZeroPoint_) + {} + }; + + struct Arguments { + GemmCoord problemShape; + uint32_t aicCoreNum; + size_t elementSize; + GM_ADDR ptrA; + GM_ADDR ptrB; + GM_ADDR ptrC; + half deqScalar; + half deqZeroPoint; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + // Calculate workspace size, using double buffer + size_t 
lenWorkspace = static_cast(L1TileShape::N) * L1TileShape::K * args.aicCoreNum * BUFFER_NUM; + size_t sizeWorkspace = lenWorkspace * args.elementSize; + return sizeWorkspace; + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + LayoutA layoutA{args.problemShape.m(), args.problemShape.k()}; + LayoutB layoutB{args.problemShape.k(), args.problemShape.n()}; + LayoutC layoutC{args.problemShape.m(), args.problemShape.n()}; + + Params params{args.problemShape, args.ptrA, layoutA, args.ptrB, layoutB, + args.ptrC, layoutC, workspace, args.deqScalar, args.deqZeroPoint}; + return params; + } + + // Methods + CATLASS_DEVICE + W8A16Matmul() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + // Represent the full gm + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ int8_t *)params.ptrB); + AscendC::GlobalTensor gmBWksp; + gmBWksp.SetGlobalBuffer((__gm__ ElementB *)params.ptrWksp); + + BlockMmad blockMmad(resource); + + GemmCoord blockIdxCoord; + GemmCoord actualBlockShape; + GemmCoord nextBlockIdCoord; + GemmCoord nextActualBlockShape; + + for (uint32_t loopIdx = AscendC::GetBlockIdx() / 2; loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + bool isFirstBlock = (loopIdx == AscendC::GetBlockIdx() / 2); + bool hasNextBlock = false; + + // Compute block location + if (isFirstBlock) { + blockIdxCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockIdxCoord); + } else { + blockIdxCoord = nextBlockIdCoord; + actualBlockShape = nextActualBlockShape; + } + if (loopIdx + AscendC::GetBlockNum() < coreLoops) { + hasNextBlock = true; + nextBlockIdCoord = matmulBlockScheduler.GetBlockCoord(loopIdx + AscendC::GetBlockNum()); + nextActualBlockShape = matmulBlockScheduler.GetActualBlockShape(nextBlockIdCoord); + } + + // Compute initial location in logical coordinates + MatrixCoord offsetB{blockIdxCoord.k() * L1TileShape::K, blockIdxCoord.n() * L1TileShape::N}; + int64_t gmOffsetB = params.layoutB.GetOffset(offsetB); + MatrixCoord offsetNextB{nextBlockIdCoord.k() * L1TileShape::K, nextBlockIdCoord.n() * L1TileShape::N}; + int64_t gmOffsetNextB = params.layoutB.GetOffset(offsetNextB); + int64_t gmOffsetBWksp = (AscendC::GetBlockIdx() / 2) * L1TileShape::K * L1TileShape::N * 2; + + // Compute block-scoped matrix multiply-add + blockMmad(gmB[gmOffsetB], params.layoutB, gmB[gmOffsetNextB], gmBWksp[gmOffsetBWksp], actualBlockShape, + nextActualBlockShape, isFirstBlock, hasNextBlock, params.deqScalar, params.deqZeroPoint); + } + + AscendC::PipeBarrier(); + } + + /// Executes matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrWksp); + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); + + BlockMmad blockMmad(resource); + + GemmCoord blockIdxCoord; + GemmCoord actualBlockShape; + GemmCoord nextBlockIdCoord; + GemmCoord nextActualBlockShape; + + for 
(uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + bool isFirstBlock = (loopIdx == AscendC::GetBlockIdx()); + bool hasNextBlock = false; + + // Compute block location + if (isFirstBlock) { + blockIdxCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); + actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockIdxCoord); + } else { + blockIdxCoord = nextBlockIdCoord; + actualBlockShape = nextActualBlockShape; + } + if (loopIdx + AscendC::GetBlockNum() < coreLoops) { + hasNextBlock = true; + nextBlockIdCoord = matmulBlockScheduler.GetBlockCoord(loopIdx + AscendC::GetBlockNum()); + nextActualBlockShape = matmulBlockScheduler.GetActualBlockShape(nextBlockIdCoord); + } + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockIdxCoord.m() * L1TileShape::M, blockIdxCoord.k() * L1TileShape::K}; + MatrixCoord offsetC{blockIdxCoord.m() * L1TileShape::M, blockIdxCoord.n() * L1TileShape::N}; + int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); + int64_t gmOffsetB = AscendC::GetBlockIdx() * L1TileShape::K * L1TileShape::N * 2; + int64_t gmOffsetC = params.layoutC.GetOffset(offsetC); + + MatrixCoord offsetNextA{nextBlockIdCoord.m() * L1TileShape::M, nextBlockIdCoord.k() * L1TileShape::K}; + int64_t gmOffsetNextA = params.layoutA.GetOffset(offsetNextA); + // Compute block-scoped matrix multiply-add + blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], gmC[gmOffsetC], params.layoutC, + gmA[gmOffsetNextA], actualBlockShape, nextActualBlockShape, isFirstBlock, hasNextBlock); + } + + AscendC::PipeBarrier(); + } + +private: + static constexpr Arch::FlagID FLAG_AIV_FINISH_STORE = 0; + static const uint32_t BUFFER_NUM = 2; + Arch::CrossCoreFlag flagAivFinishPadding{FLAG_AIV_FINISH_STORE}; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // CATLASS_GEMM_KERNEL_W8A16_MATMUL_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_gm_to_l1.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_gm_to_l1.hpp new file mode 100644 index 00000000..371b3eb3 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_gm_to_l1.hpp @@ -0,0 +1,1490 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_TILE_COPY_GM_TO_L1_HPP +#define CATLASS_GEMM_TILE_COPY_GM_TO_L1_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm/gemm_type.hpp" +#include "catlass/gemm/tile/tile_copy_tla.hpp" +#include "tla/tensor.hpp" + +namespace Catlass::Gemm::Tile { + +template +struct CopyGmToL1 { + static_assert(DEPENDENT_FALSE, "Unsupported copy gm to l1, can not find the specialization."); +}; + +template +struct CopyGmToL1IntervalDataCopy { + static_assert(DEPENDENT_FALSE, "Unsupported copy gm to l1, can not find the specialization."); +}; + +template +struct CopyGmToL1GMMPTD { + static_assert(DEPENDENT_FALSE, "Unsupported copy gm to l1, can not find the specialization."); +}; + +/// Partial specialization for AtlasA2, RowMajor in and zN out. +template +struct CopyGmToL1GMMPTD> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1GMMPTD() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.shape(0) == 1) { + // If the number of matrix rows is 1, the regular interval-based DataCopy interface can be used instead of + // the ND2NZ DataCopy interface, resulting in higher transfer efficiency. + AscendC::DataCopyParams dataCopyParams(CeilDiv(layoutSrc.shape(1), layoutDst.shape(2)), + layoutDst.shape(2) / ELE_NUM_PER_C0, 0, + (layoutDst.stride(3) - layoutDst.shape(2)) / ELE_NUM_PER_C0); + AscendC::DataCopy(dstTensor, srcTensor, dataCopyParams); + } else { + if (layoutSrc.stride(0) < STRIDE_LIMIT) { + if (layoutSrc.shape(1) != ELE_NUM_PER_C0 || layoutSrc.stride(0) != ELE_NUM_PER_C0) { + intriParams.nValue = layoutSrc.shape(0); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + // If the matrix has ELE_NUM_PER_C0 columns and a stride of ELE_NUM_PER_C0, it follows a row-major + // layout in L1, allowing the use of the standard contiguous DataCopy interface for more efficient + // transfers. 
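+ // The contiguous copy length is simply rows * columns elements of the source tile.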
+ AscendC::DataCopy(dstTensor, srcTensor, layoutSrc.shape(0) * layoutSrc.shape(1)); + } + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(0)], intriParams); + } + } + } + } + + // layoutSrc must be the layout of one of the src matrices + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc, uint32_t ndNum, uint32_t srcNdMatrixStride, + uint32_t dstNzNStride, uint32_t dstNzMatrixStride, uint32_t dstNzC0Stride) + { + AscendC::Nd2NzParams intriParams; + + intriParams.nValue = layoutSrc.shape(0); + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = dstNzNStride; + intriParams.dstNzC0Stride = dstNzC0Stride; + if (srcNdMatrixStride < STRIDE_LIMIT) { + intriParams.ndNum = ndNum; + intriParams.srcNdMatrixStride = srcNdMatrixStride; + intriParams.dstNzMatrixStride = dstNzMatrixStride; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.ndNum = 1; + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzMatrixStride = 0; + for (uint32_t i = 0; i < ndNum; i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * srcNdMatrixStride], intriParams); + } + } + } +}; + +//////////////////////////////////////// +/// Using the standard strided DataCopy interface to implement nd2nz +/// transfer may achieve higher data transfer efficiency when the data block shape is short and wide +/// Partial specialization for AtlasA2, half, RowMajor in and zN out. +template <> +struct CopyGmToL1IntervalDataCopy> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::RowMajor; + using Element = half; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1IntervalDataCopy() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + for (int i = 0; i < layoutSrc.shape(0); ++i) { + AscendC::DataCopyParams dataCopyParams(CeilDiv(layoutSrc.shape(1), layoutDst.shape(2)), + layoutDst.shape(2) / ELE_NUM_PER_C0, 0, + (layoutDst.stride(3) - layoutDst.shape(2)) / ELE_NUM_PER_C0); + AscendC::DataCopy(dstTensor[i * layoutDst.shape(2)], srcTensor[i * layoutSrc.stride(0)], dataCopyParams); + } + } +}; + +/// Partial specialization for AtlasA2, half, PaddingRowMajor in and zN out. 
+/// Using the standard strided DataCopy interface to implement nd2nz +/// transfer may achieve higher data transfer efficiency when the data block shape is short and wide +template <> +struct CopyGmToL1IntervalDataCopy> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::PaddingRowMajor; + using Element = half; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1IntervalDataCopy() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + for (int i = 0; i < layoutSrc.orgShape(0); ++i) { + AscendC::DataCopyParams dataCopyParams(CeilDiv(layoutSrc.orgShape(1), layoutDst.shape(2)), + layoutDst.shape(2) / ELE_NUM_PER_C0, 0, + (layoutDst.stride(3) - layoutDst.shape(2)) / ELE_NUM_PER_C0); + AscendC::DataCopy(dstTensor[i * layoutDst.shape(2)], srcTensor[i * layoutSrc.stride(0)], dataCopyParams); + } + } +}; + +/// Partial specialization for AtlasA2, half, ColumnMajor in and zN out. +/// Using the standard strided DataCopy interface to implement nd2nz +/// transfer may achieve higher data transfer efficiency when the data block shape is tall and narrow +template <> +struct CopyGmToL1IntervalDataCopy> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::ColumnMajor; + using Element = half; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1IntervalDataCopy() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + for (int i = 0; i < layoutSrc.shape(1); ++i) { + AscendC::DataCopyParams dataCopyParams(CeilDiv(layoutSrc.shape(0), layoutDst.shape(0)), + layoutDst.shape(0) / ELE_NUM_PER_C0, 0, + (layoutDst.stride(1) - layoutDst.shape(0)) / ELE_NUM_PER_C0); + AscendC::DataCopy(dstTensor[i * layoutDst.shape(0)], srcTensor[i * layoutSrc.stride(1)], dataCopyParams); + } + } +}; + +/// Partial specialization for AtlasA2, half, PaddingColumnMajor in and zN out. 
+/// Using the standard strided DataCopy interface to implement nd2nz +/// transfer may achieve higher data transfer efficiency when the data block shape is tall and narrow +template <> +struct CopyGmToL1IntervalDataCopy> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::PaddingColumnMajor; + using Element = half; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1IntervalDataCopy() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + for (int i = 0; i < layoutSrc.orgShape(1); ++i) { + AscendC::DataCopyParams dataCopyParams(CeilDiv(layoutSrc.orgShape(0), layoutDst.shape(0)), + layoutDst.shape(0) / ELE_NUM_PER_C0, 0, + (layoutDst.stride(1) - layoutDst.shape(0)) / ELE_NUM_PER_C0); + AscendC::DataCopy(dstTensor[i * layoutDst.shape(0)], srcTensor[i * layoutSrc.stride(2)], dataCopyParams); + } + } +}; + +/// new add gemm +template +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(0) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(0); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(0)], intriParams); + } + } + } +}; + +template +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + uint32_t srcNdStride = C0_NUM_PER_FRACTAL * layoutSrc.stride(0); + uint32_t ndNum = layoutSrc.shape(0) / C0_NUM_PER_FRACTAL; + uint32_t remains = layoutSrc.shape(0) % C0_NUM_PER_FRACTAL; + if (srcNdStride < STRIDE_LIMIT) { + if (ndNum) { + intriParams.ndNum = ndNum; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = srcNdStride; + intriParams.srcDValue = layoutSrc.stride(0); + + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + + intriParams.dstNzMatrixStride = layoutDst.stride(1); + + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } + + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = 
layoutSrc.shape(1); + tailParams.srcNdMatrixStride = srcNdStride; + tailParams.srcDValue = layoutSrc.stride(0); + + tailParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; //` + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(1)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else if (layoutSrc.stride(0) < STRIDE_LIMIT) { + for (uint32_t i = 0; i < ndNum; i++) { + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = layoutSrc.stride(0); + + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[i * layoutDst.stride(1)], srcTensor[i * srcNdStride], intriParams); + } + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(1); + tailParams.srcNdMatrixStride = 0; + tailParams.srcDValue = layoutSrc.stride(0); + + tailParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(1)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else { + for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { + uint32_t idxR0 = i / C0_NUM_PER_FRACTAL; + uint32_t idxInR0 = i % C0_NUM_PER_FRACTAL; + + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = 1; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = 0; + + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = 0; + intriParams.dstNzMatrixStride = 0; + + uint32_t offsetDst = i * idxR0 * layoutDst.stride(1) + idxInR0 * ELE_NUM_PER_C0; + uint32_t offsetSrc = i * layoutSrc.stride(0); + AscendC::DataCopy(dstTensor[offsetDst], srcTensor[offsetSrc], intriParams); + } + } + } +}; + +template +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::nN; + using LayoutSrc = layout::ColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + uint32_t srcNdStride = C0_NUM_PER_FRACTAL * layoutSrc.stride(1); + uint32_t ndNum = layoutSrc.shape(1) / C0_NUM_PER_FRACTAL; + uint32_t remains = layoutSrc.shape(1) % C0_NUM_PER_FRACTAL; + if (srcNdStride < STRIDE_LIMIT) { + if (ndNum) { + intriParams.ndNum = ndNum; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = srcNdStride; + intriParams.srcDValue = layoutSrc.stride(1); + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + + intriParams.dstNzMatrixStride = layoutDst.stride(3); + + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } + + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(0); 
+ tailParams.srcNdMatrixStride = srcNdStride; + tailParams.srcDValue = layoutSrc.stride(1); + + tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else if (layoutSrc.stride(1) < STRIDE_LIMIT) { + for (uint32_t i = 0; i < ndNum; i++) { + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = layoutSrc.stride(1); + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[i * layoutDst.stride(3)], srcTensor[i * srcNdStride], intriParams); + } + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(0); + tailParams.srcNdMatrixStride = 0; + tailParams.srcDValue = layoutSrc.stride(1); + + tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else { + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + uint32_t idxR0 = i / C0_NUM_PER_FRACTAL; + uint32_t idxInR0 = i % C0_NUM_PER_FRACTAL; + + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = 0; + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = 0; + intriParams.dstNzMatrixStride = 0; + + uint32_t offsetDst = i * idxR0 * layoutDst.stride(3) + idxInR0 * ELE_NUM_PER_C0; + uint32_t offsetSrc = i * layoutSrc.stride(1); + AscendC::DataCopy(dstTensor[offsetDst], srcTensor[offsetSrc], intriParams); + } + } + } +}; + +template +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::ColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(1) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(1); + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(1)], intriParams); + } + } + } +}; + +template +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::ColumnMajor; + + static 
constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(1) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(1); + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(1)], intriParams); + } + } + } +}; +//////////////////////////////////////// + +/////////////////////////////////////// +/// new add gemv, VectorLayout -> zN +template +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::VectorLayout; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + intriParams.nValue = 1; + intriParams.srcDValue = layoutSrc.shape(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } +}; + +template +struct CopyGmToL1> { + using LayoutDst = layout::NDC1HWC0; + using LayoutSrc = layout::NDC1HWC0; + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + const static uint64_t MAX_UINT16 = 65535; + + uint32_t cin1LoadL1 = layoutDst.orgShape(2); + uint32_t hiLoadL1 = layoutDst.orgShape(3); + + uint32_t dilationD = layoutSrc.orgShape(1); + uint32_t OriC1 = layoutSrc.orgShape(2); + uint32_t OriH = layoutSrc.orgShape(3); + uint32_t OriW = layoutSrc.orgShape(4); + uint32_t OriK0 = layoutSrc.orgShape(5); + + uint64_t dataCopyLoop = CeilDiv(cin1LoadL1, OriC1); + uint64_t dataCopySubLoop = 0; + uint64_t blockCount = OriC1; + bool srcStrideBeyondMaxU16 = false; + if (OriH * OriW - hiLoadL1 * OriW > MAX_UINT16) { + dataCopySubLoop = dataCopyLoop > 0 ? 
cin1LoadL1 / dataCopyLoop : 0; + blockCount = 1; + srcStrideBeyondMaxU16 = true; + } + + uint64_t aL1GmOffset = 0; + uint64_t aL1Offset = 0; + if (cin1LoadL1 > OriC1 || srcStrideBeyondMaxU16) { + repeatParams.blockCount = blockCount; + repeatParams.blockLen = hiLoadL1 * OriW; + repeatParams.srcStride = OriW * OriH - repeatParams.blockLen; + repeatParams.dstStride = 0; + for (uint64_t i = 0; i < dataCopyLoop; i++) { + if (srcStrideBeyondMaxU16) { + uint64_t aL1GmSubOffset = aL1GmOffset; + uint64_t aL1SubOffset = aL1Offset; + for (uint64_t j = 0; j < dataCopySubLoop; j++) { + AscendC::DataCopy(dstTensor[aL1SubOffset], srcTensor[aL1GmSubOffset], repeatParams); + aL1GmSubOffset += OriH * OriW * OriK0; + aL1SubOffset += hiLoadL1 * OriW * OriK0; + } + } else { + AscendC::DataCopy(dstTensor[aL1Offset], srcTensor[aL1GmOffset], repeatParams); + } + aL1GmOffset += dilationD * OriC1 * OriH * OriW * OriK0; + aL1Offset += OriC1 * hiLoadL1 * OriW * OriK0; + } + } else { + repeatParams.blockCount = cin1LoadL1; + repeatParams.blockLen = hiLoadL1 * OriW; + repeatParams.srcStride = OriW * OriH - repeatParams.blockLen; + repeatParams.dstStride = 0; + AscendC::DataCopy(dstTensor[aL1Offset], srcTensor[aL1GmOffset], repeatParams); + aL1Offset += cin1LoadL1 * hiLoadL1 * OriW * OriK0; + } + } + +private: + uint64_t aL1GmOffset = 0; + AscendC::DataCopyParams repeatParams; +}; + +template +struct CopyGmToL1> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::KDC1KHKWN1N0C0; + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t currentNBL1 = layoutDst.orgShape(1); + uint32_t currentKBL1 = layoutDst.orgShape(0); + + uint32_t N1 = layoutSrc.shape(3); + uint32_t N0 = layoutSrc.shape(2); + uint32_t C0 = layoutSrc.shape(0); + uint32_t OriCoAlign = N1 * N0; + + const static uint32_t LOAD2D_MAX_REPEAT_TIMES = 255; + + if (currentNBL1 >= OriCoAlign) { + uint32_t repeatTimes = (currentKBL1 * currentNBL1) / (N0 * C0); + if (repeatTimes > LOAD2D_MAX_REPEAT_TIMES) { + repeatParams.blockCount = 1; + repeatParams.srcStride = 0; + repeatParams.blockLen = CeilDiv(currentKBL1 * currentNBL1, C0); + AscendC::DataCopy(dstTensor, srcTensor, repeatParams); + } else { + AscendC::LoadData2DParams loadData2dParams; + loadData2dParams.srcStride = 1; + loadData2dParams.repeatTimes = repeatTimes; + AscendC::LoadData(dstTensor, srcTensor, loadData2dParams); + } + } else { + repeatParams.blockCount = currentKBL1 / C0; + repeatParams.blockLen = currentNBL1; + repeatParams.srcStride = N1 * N0 - currentNBL1; + AscendC::DataCopy(dstTensor, srcTensor, repeatParams); + } + } + +private: + AscendC::DataCopyParams repeatParams; +}; + +/////////////////////////////////////// +/// new add gemv, ColumnMajor -> nN +template +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::nN; + using LayoutSrc = layout::ColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + uint32_t srcNdStride = C0_NUM_PER_FRACTAL * layoutSrc.stride(1); + uint32_t ndNum = layoutSrc.shape(1) / C0_NUM_PER_FRACTAL; + uint32_t remains = layoutSrc.shape(1) % C0_NUM_PER_FRACTAL; + 
if (srcNdStride < STRIDE_LIMIT) { + if (ndNum) { + intriParams.ndNum = ndNum; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = srcNdStride; + intriParams.srcDValue = layoutSrc.stride(1); + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + + intriParams.dstNzMatrixStride = layoutDst.stride(3); + + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } + + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(0); + tailParams.srcNdMatrixStride = srcNdStride; + tailParams.srcDValue = layoutSrc.stride(1); + + tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else if (layoutSrc.stride(1) < STRIDE_LIMIT) { + for (uint32_t i = 0; i < ndNum; i++) { + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = layoutSrc.stride(1); + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[i * layoutDst.stride(3)], srcTensor[i * srcNdStride], intriParams); + } + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(0); + tailParams.srcNdMatrixStride = 0; + tailParams.srcDValue = layoutSrc.stride(1); + + tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else { + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + uint32_t idxR0 = i / C0_NUM_PER_FRACTAL; + uint32_t idxInR0 = i % C0_NUM_PER_FRACTAL; + + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = 0; + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = 0; + intriParams.dstNzMatrixStride = 0; + + uint32_t offsetDst = i * idxR0 * layoutDst.stride(3) + idxInR0 * ELE_NUM_PER_C0; + uint32_t offsetSrc = i * layoutSrc.stride(1); + AscendC::DataCopy(dstTensor[offsetDst], srcTensor[offsetSrc], intriParams); + } + } + } +}; + +template +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + 
if (layoutSrc.stride(0) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(0); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(0)], intriParams); + } + } + } +}; +///////////////////////////////// + +/// Partial specialization for AtlasA2, RowMajor in and zN out. +template +struct CopyGmToL1> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(0) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(0); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(0)], intriParams); + } + } + } + + // layoutSrc must be the layout of one of the src matrices + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc, uint32_t ndNum, uint32_t srcNdMatrixStride, + uint32_t dstNzNStride, uint32_t dstNzMatrixStride, uint32_t dstNzC0Stride) + { + AscendC::Nd2NzParams intriParams; + + intriParams.nValue = layoutSrc.shape(0); + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = dstNzNStride; + intriParams.dstNzC0Stride = dstNzC0Stride; + if (srcNdMatrixStride < STRIDE_LIMIT) { + intriParams.ndNum = ndNum; + intriParams.srcNdMatrixStride = srcNdMatrixStride; + intriParams.dstNzMatrixStride = dstNzMatrixStride; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.ndNum = 1; + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzMatrixStride = 0; + for (uint32_t i = 0; i < ndNum; i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * srcNdMatrixStride], intriParams); + } + } + } +}; + +/// Partial specialization for AtlasA2, ColumnMajor in and nZ out. 
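All of the ND-to-NZ CopyGmToL1 specializations above share one three-way dispatch on the hardware stride fields. The sketch below only distills that control flow; the function name and its parameters are illustrative placeholders, not part of the Catlass API.

// Distilled control flow of the Nd2Nz-based CopyGmToL1 specializations above (illustrative only).
template <class DstTensor, class SrcTensor>
CATLASS_DEVICE void CopyNdToNzSketch(DstTensor const &dst, SrcTensor const &src,
                                     uint32_t srcNdStride, uint32_t srcRowStride,
                                     uint32_t ndNum, uint32_t remains)
{
    if (srcNdStride < STRIDE_LIMIT) {
        // Fast path: one Nd2Nz DataCopy moves ndNum full fractals at once,
        // followed by a single tail copy for the `remains` rows that do not
        // fill a complete fractal.
    } else if (srcRowStride < STRIDE_LIMIT) {
        // The combined ND stride overflows its hardware field but a single
        // matrix still fits: loop over the fractals and issue one Nd2Nz
        // DataCopy each, again with a separate tail copy.
    } else {
        // Even the per-row stride overflows: copy one source row per DataCopy
        // and place it at the fractal-aligned destination offset derived from
        // row / C0_NUM_PER_FRACTAL and row % C0_NUM_PER_FRACTAL.
    }
}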
+template +struct CopyGmToL1> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::ColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(1) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(1); + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(1)], intriParams); + } + } + } +}; + +/// Partial specialization for zN in and zN out. +template +struct CopyGmToL1> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t blockCount = CeilDiv(layoutSrc.orgShape(1)); + uint32_t blockLen = RoundUp(layoutSrc.orgShape(0)); + + AscendC::DataCopyParams repeatParams; + + if (layoutSrc.stride(3) / ELE_NUM_PER_C0 < STRIDE_LIMIT) { + repeatParams.blockCount = blockCount; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_C0 - blockLen; + repeatParams.dstStride = layoutDst.stride(3) / ELE_NUM_PER_C0 - blockLen; + AscendC::DataCopy(dstTensor, srcTensor, repeatParams); + } else { + repeatParams.blockCount = 1; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = 0; + repeatParams.dstStride = 0; + for (uint32_t i = 0; i < blockCount; i++) { + uint64_t dstOffset = i * layoutDst.stride(3); + uint64_t srcOffset = i * layoutSrc.stride(3); + AscendC::DataCopy(dstTensor[dstOffset], srcTensor[srcOffset], repeatParams); + } + } + } +}; + +/// Partial specialization for nZ in and nZ out. 
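Every specialization in this file converts element strides into C0-line strides through ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element). The worked arithmetic below assumes the usual 32-byte C0 line of AtlasA2; BYTE_PER_C0 itself is defined elsewhere in Catlass, so the concrete value is an assumption here.

// Elements per C0 line, assuming BYTE_PER_C0 == 32 (assumption, see lead-in):
//   int8_t : 32 / 1 = 32 elements per C0 line
//   half   : 32 / 2 = 16 elements per C0 line
//   float  : 32 / 4 =  8 elements per C0 line
// Example: a zN destination whose stride(3) is 8192 half elements between C0
// columns yields dstNzC0Stride = 8192 / 16 = 512 C0 lines.
static_assert(32 / sizeof(int8_t) == 32 && 32 / sizeof(float) == 8,
              "illustrative arithmetic only");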
+template +struct CopyGmToL1> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t blockCount = CeilDiv(layoutSrc.orgShape(0)); + uint32_t blockLen = RoundUp(layoutSrc.orgShape(1)); + + AscendC::DataCopyParams repeatParams; + + if (layoutSrc.stride(1) / ELE_NUM_PER_C0 < STRIDE_LIMIT) { + repeatParams.blockCount = blockCount; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_C0 - blockLen; + repeatParams.dstStride = layoutDst.stride(1) / ELE_NUM_PER_C0 - blockLen; + AscendC::DataCopy(dstTensor, srcTensor, repeatParams); + } else { + repeatParams.blockCount = 1; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = 0; + repeatParams.dstStride = 0; + for (uint32_t i = 0; i < blockCount; i++) { + uint64_t dstOffset = i * layoutDst.stride(1); + uint64_t srcOffset = i * layoutSrc.stride(1); + AscendC::DataCopy(dstTensor[dstOffset], srcTensor[srcOffset], repeatParams); + } + } + } +}; + +/// Partial specialization for AtlasA2, PaddingRowMajor in and zN out. +template +struct CopyGmToL1> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::PaddingRowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.orgShape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = layoutSrc.orgShape(0); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } +}; + +/// Partial specialization for AtlasA2, ColumnMajor in and nZ out. +template +struct CopyGmToL1> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::PaddingColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.orgShape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = layoutSrc.orgShape(1); + intriParams.srcDValue = layoutSrc.stride(2); + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } +}; + +/// Partial specialization for AtlasA2, RowMajor in and RowMajor out. 
+template +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::RowMajor; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; + static constexpr uint32_t MAX_REPEAT = 4095; + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t rows = layoutSrc.shape(0); + uint32_t cols = layoutSrc.shape(1); + uint32_t srcStride = (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; + uint32_t dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; + + if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && (layoutDst.shape(1) == layoutDst.stride(0))) { + DataCopy(dstTensor, srcTensor, rows * cols); + } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { + uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); + for (uint32_t i = 0; i < rLoops; ++i) { + uint32_t rActual = (i < rLoops - 1) ? MAX_REPEAT : rows - i * MAX_REPEAT; + AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, srcStride, dstStride); + DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], + srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], dataCopyParams); + } + } else { + for (uint32_t i = 0; i < rows; ++i) { + DataCopy(dstTensor[i * layoutDst.stride(0)], srcTensor[i * layoutSrc.stride(0)], cols); + } + } + } +}; + +///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// +/// Partial specialization for CopyGmToL1, AtlasA2, RowMajor in and zN out. +template +struct TileCopyTla< + Arch::AtlasA2, tla::Tensor, LayoutSrc, CoordSrc, AscendC::TPosition::GM>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::A1>, + std::enable_if_t::value && tla::detail::iszN::value>> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + CATLASS_DEVICE + TileCopyTla() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + static_assert( + tla::detail::isRowMajor::value && + tla::detail::iszN::value && + TensorSrc::position == AscendC::TPosition::GM && TensorDst::position == AscendC::TPosition::A1, + "The input parameters do not match. 
TensorSrc must be GM and RowMajor, while TensorDst must be L1 and zN"); + + const uint32_t nValue = tla::get<0>(srcTensor.shape()); + const uint32_t dValue = tla::get<1>(srcTensor.shape()); + const uint32_t srcDValue = tla::get<0>(srcTensor.stride()); + const uint32_t dstInnerStrideRow = tla::get<0, 0>(dstTensor.stride()); + const uint32_t dstOuterStrideCol = tla::get<1, 1>(dstTensor.stride()); + + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = dValue; + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = dstOuterStrideCol / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + + if (srcDValue < STRIDE_LIMIT) { + intriParams.nValue = nValue; + intriParams.srcDValue = srcDValue; + intriParams.dstNzNStride = dstInnerStrideRow / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data()[dstOffset], srcTensor.data()[srcOffset], intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < nValue; i++) { + AscendC::DataCopy(dstTensor.data()[dstOffset + i * ELE_NUM_PER_C0], + srcTensor.data()[srcOffset + i * srcDValue], intriParams); + } + } + } +}; + +/// Partial specialization for CopyGmToL1, AtlasA2, ColumnMajor in and nZ out. +template +struct TileCopyTla< + Arch::AtlasA2, tla::Tensor, LayoutSrc, CoordSrc, AscendC::TPosition::GM>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::A1>, + std::enable_if_t::value && tla::detail::isnZ::value>> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + CATLASS_DEVICE + TileCopyTla() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + static_assert(tla::detail::isColumnMajor::value && + tla::detail::isnZ::value && + TensorSrc::position == AscendC::TPosition::GM && + TensorDst::position == AscendC::TPosition::A1, + "The input parameters do not match. TensorSrc must be GM and ColumnMajor, " + "while TensorDst must be L1 and nZ"); + + const uint32_t nValue = tla::get<1>(srcTensor.shape()); + const uint32_t dValue = tla::get<0>(srcTensor.shape()); + const uint32_t srcDValue = tla::get<1>(srcTensor.stride()); + const uint32_t dstInnerStrideRow = tla::get<1, 0>(dstTensor.stride()); + const uint32_t dstOuterStrideCol = tla::get<0, 1>(dstTensor.stride()); + + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = dValue; + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = dstOuterStrideCol / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + + if (srcDValue < STRIDE_LIMIT) { + intriParams.nValue = nValue; + intriParams.srcDValue = srcDValue; + intriParams.dstNzNStride = dstInnerStrideRow / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data()[dstOffset], srcTensor.data()[srcOffset], intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < nValue; i++) { + AscendC::DataCopy(dstTensor.data()[dstOffset + i * ELE_NUM_PER_C0], + srcTensor.data()[srcOffset + i * srcDValue], intriParams); + } + } + } +}; + +/// Partial specialization for TileCopyTlaExt, CopyGmToL1, AtlasA2, PaddingRowMajor in and zN out. 
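Unlike the layout-tag CopyGmToL1 functors, the tla-based TileCopyTla specializations above carry layout, coordinate and memory position in the tensor types themselves, so the call site reduces to a functor invocation. A minimal usage sketch, assuming gmTensor and l1Tensor are tla::Tensor objects already built with the GM/RowMajor and A1/zN traits that the static_assert demands (their construction depends on the tla helpers and is not shown):

// Hypothetical call site; the trailing enable_if parameter of TileCopyTla is
// assumed to be defaulted, so only arch, source and destination types are named.
Gemm::Tile::TileCopyTla<Arch::AtlasA2, decltype(gmTensor), decltype(l1Tensor)> tileCopy;
tileCopy(l1Tensor, gmTensor);  // destination first, source second, as in operator() above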
+template +struct TileCopyTlaExt, LayoutSrc, CoordSrc, AscendC::TPosition::GM>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::A1>, + layout::RowMajor, layout::zN> { + using ActualShape = tla::Shape; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + CATLASS_DEVICE + TileCopyTlaExt() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor, ActualShape actualShape) + { + static_assert( + tla::detail::isRowMajor::value && + tla::detail::iszN::value && + TensorSrc::position == AscendC::TPosition::GM && TensorDst::position == AscendC::TPosition::A1, + "The input parameters do not match. TensorSrc must be GM and RowMajor, while TensorDst must be L1 and zN"); + + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = tla::get<1>(actualShape); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = tla::get<1, 1>(dstTensor.stride()) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = tla::get<0>(actualShape); + intriParams.srcDValue = tla::get<0>(srcTensor.stride()); + intriParams.dstNzNStride = tla::get<0, 0>(dstTensor.stride()) / ELE_NUM_PER_C0; + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + AscendC::DataCopy(dstTensor.data()[dstOffset], srcTensor.data()[srcOffset], intriParams); + } +}; + +/// Partial specialization for TileCopyTlaExt, CopyGmToL1, AtlasA2, PaddingRowMajor in and zN out. +template +struct TileCopyTlaExt, LayoutSrc, CoordSrc, AscendC::TPosition::GM>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::A1>, + layout::PaddingRowMajor, layout::zN> { + using ActualShape = tla::Shape; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + CATLASS_DEVICE + TileCopyTlaExt() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor, ActualShape actualShape) + { + static_assert(tla::detail::iszN::value && + TensorSrc::position == AscendC::TPosition::GM && + TensorDst::position == AscendC::TPosition::A1, + "The input parameters do not match. TensorSrc must be GM and PaddingRowMajor, " + "while TensorDst must be L1 and zN"); + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = tla::get<1>(actualShape); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = tla::get<1, 1>(dstTensor.stride()) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = tla::get<0>(actualShape); + intriParams.srcDValue = tla::get<0, 0>(srcTensor.stride()); + intriParams.dstNzNStride = tla::get<0, 0>(dstTensor.stride()) / ELE_NUM_PER_C0; + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + AscendC::DataCopy(dstTensor.data()[dstOffset], srcTensor.data()[srcOffset], intriParams); + } +}; + +/// Partial specialization for TileCopyTlaExt, CopyGmToL1, AtlasA2, PaddingColumnMajor in and nZ out. 
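The Ext variants differ from plain TileCopyTla only in the extra ActualShape argument: the copied nValue/dValue come from that runtime shape rather than from the tensor layout, which is what lets a padded L1 tile load just its valid rows and columns. A call sketch follows; the exact template-parameter order (in particular the leading architecture tag) is assumed to mirror TileCopyTla, and gmTile, l1Tile and actualShape are placeholders supplied by the surrounding kernel.

// Hypothetical call site for the PaddingRowMajor -> zN Ext copy above.
using CopyExt = Gemm::Tile::TileCopyTlaExt<Arch::AtlasA2, decltype(gmTile), decltype(l1Tile),
                                           layout::PaddingRowMajor, layout::zN>;
CopyExt tileCopyExt;
tileCopyExt(l1Tile, gmTile, actualShape);  // copies only actualShape rows/cols of the padded tile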
+template +struct TileCopyTlaExt, LayoutSrc, CoordSrc, AscendC::TPosition::GM>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::A1>, + layout::ColumnMajor, layout::nZ> { + using ActualShape = tla::Shape; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + CATLASS_DEVICE + TileCopyTlaExt() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor, ActualShape actualShape) + { + static_assert(tla::detail::isColumnMajor::value && + tla::detail::isnZ::value && + TensorSrc::position == AscendC::TPosition::GM && + TensorDst::position == AscendC::TPosition::A1, + "The input parameters do not match. TensorSrc must be GM and ColumnMajor, " + "while TensorDst must be L1 and nZ"); + + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = tla::get<0>(actualShape); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = tla::get<0, 1>(dstTensor.stride()) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = tla::get<1>(actualShape); + intriParams.srcDValue = tla::get<1>(srcTensor.stride()); + intriParams.dstNzNStride = tla::get<1, 0>(dstTensor.stride()) / ELE_NUM_PER_C0; + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + AscendC::DataCopy(dstTensor.data()[dstOffset], srcTensor.data()[srcOffset], intriParams); + } +}; + +/// Partial specialization for TileCopyTlaExt, CopyGmToL1, AtlasA2, PaddingColumnMajor in and nZ out. +template +struct TileCopyTlaExt, LayoutSrc, CoordSrc, AscendC::TPosition::GM>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::A1>, + layout::PaddingColumnMajor, layout::nZ> { + using ActualShape = tla::Shape; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + CATLASS_DEVICE + TileCopyTlaExt() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor, ActualShape actualShape) + { + static_assert(tla::detail::isnZ::value && + TensorSrc::position == AscendC::TPosition::GM && + TensorDst::position == AscendC::TPosition::A1, + "The input parameters do not match. 
TensorSrc must be GM and PaddingColumnMajor, " + "while TensorDst must be L1 and nZ"); + + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = tla::get<0>(actualShape); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = tla::get<0, 1>(dstTensor.stride()) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = tla::get<1>(actualShape); + intriParams.srcDValue = tla::get<1, 0>(srcTensor.stride()); + intriParams.dstNzNStride = tla::get<1, 0>(dstTensor.stride()) / ELE_NUM_PER_C0; + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + AscendC::DataCopy(dstTensor.data()[dstOffset], srcTensor.data()[srcOffset], intriParams); + } +}; + +template +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::VectorLayout; + using LayoutSrc = layout::VectorLayout; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyGmToL1() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = layoutDst.shape(0) / ELE_NUM_PER_C0; + intriParams.srcStride = 0; + intriParams.dstStride = 0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Tile + +#endif // CATLASS_GEMM_TILE_COPY_GM_TO_L1_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_gm_to_ub.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_gm_to_ub.hpp new file mode 100644 index 00000000..fa43b3d2 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_gm_to_ub.hpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_TILE_COPY_GM_TO_UB_HPP +#define CATLASS_GEMM_TILE_COPY_GM_TO_UB_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/gemm/tile/tile_copy_tla.hpp" +#include "tla/tensor.hpp" + +namespace Catlass::Gemm::Tile { + +/// Partial specialization for AtlasA2, RowMajor in and RowMajor out. 
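The GM-to-UB copy defined just below takes a different route from the L1 paths above: instead of C0-block DataCopyParams or Nd2NzParams it uses DataCopyPad with byte-granular DataCopyExtParams. The restatement below assumes BYTE_PER_BLK is the usual 32-byte unified-buffer block.

// Parameters of the GM -> UB RowMajor copy below (restated, not new API):
//   blockCount = rows
//   blockLen   = cols * sizeof(ElementSrc)                    // bytes
//   srcStride  = (srcRowStride - cols) * sizeof(ElementSrc)   // GM gap in bytes
//   dstStride  = (dstRowStride - dstCols) / ELE_NUM_PER_BLK   // UB gap in 32-byte blocks (assumed)
// Example: a 16 x 128 half tile read from a GM matrix with a 256-element row
// stride gives blockLen = 128 * 2 = 256 bytes and srcStride = (256 - 128) * 2 = 256 bytes.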
+template +struct TileCopyTla< + Arch::AtlasA2, tla::Tensor, LayoutSrc, CoordSrc, AscendC::TPosition::GM>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::VECCALC>, + std::enable_if_t::value && tla::detail::isRowMajor::value>> { + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(ElementSrc); + + // Methods + + CATLASS_DEVICE + TileCopyTla() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + static_assert(tla::detail::isRowMajor::value && + tla::detail::isRowMajor::value && + TensorSrc::position == AscendC::TPosition::GM && + TensorDst::position == AscendC::TPosition::VECCALC, + "The input parameters do not match. TensorSrc must be GM and RowMajor, " + "while TensorDst must be UB and RowMajor"); + + AscendC::DataCopyExtParams dataCopyParams( + tla::get<0>(srcTensor.shape()), tla::get<1>(srcTensor.shape()) * sizeof(ElementSrc), + (tla::get<0>(srcTensor.stride()) - tla::get<1>(srcTensor.shape())) * sizeof(ElementSrc), + (tla::get<0>(dstTensor.stride()) - tla::get<1>(dstTensor.shape())) / ELE_NUM_PER_BLK, 0); + AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + AscendC::DataCopyPad(dstTensor.data()[dstOffset], srcTensor.data()[srcOffset], dataCopyParams, padParams); + }; +}; + +} // namespace Catlass::Gemm::Tile + +#endif // CATLASS_GEMM_TILE_COPY_GM_TO_UB_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_l0c_to_gm.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_l0c_to_gm.hpp new file mode 100644 index 00000000..6509368a --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_l0c_to_gm.hpp @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_TILE_COPY_L0C_TO_GM_HPP +#define CATLASS_GEMM_TILE_COPY_L0C_TO_GM_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/gemm/gemm_type.hpp" +#include "tla/tensor.hpp" + +namespace Catlass::Gemm::Tile { + +enum class ScaleGranularity { UNDEFINED = -1, NO_QUANT = 0, PER_TENSOR, PER_CHANNEL, PER_GROUP }; + +template +struct CopyL0CToGmQuantMode { + static_assert(DEPENDENT_FALSE, "Unsupported copy l0c to gm, can not find the specialization."); +}; + +// CopyL0CToGm cast fp32 to fp16 +template <> +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::F322F16; +}; + +// CopyL0CToGm cast fp32 to bf16 +template <> +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::F322BF16; +}; + +// CopyL0CToGm output fp32 +template <> +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::NoQuant; +}; + +// CopyL0CToGm output int32 +template <> +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::NoQuant; +}; + +// CopyL0CToGm cast int32_t to fp16 +template <> +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::DEQF16; +}; + +template <> +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::VDEQF16; +}; + +template +struct CopyL0CToGm { + static_assert(DEPENDENT_FALSE, "Unsupported copy l0c to gm, can not find the specialization."); +}; + +template +struct CopyL0CToGm, + ScaleGranularity::NO_QUANT, ReluEnable_> { + using ArchTag = Catlass::Arch::AtlasA2; + using ElementDst = ElementDst_; + using ElementSrc = ElementAccumulator_; + using LayoutSrc = Catlass::layout::zN; + using LayoutDst = Catlass::layout::RowMajor; + static constexpr auto quantPre = + CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::LocalTensor const &src, + LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) + { + AscendC::FixpipeParamsV220 intriParams; + + // Fixpipe layout information + intriParams.nSize = dstLayout.shape(1); + intriParams.mSize = dstLayout.shape(0); + intriParams.srcStride = srcLayout.stride(3) / srcLayout.stride(0); + intriParams.dstStride = dstLayout.stride(0); + + // Fixpipe auxiliary arguments + intriParams.quantPre = quantPre; + intriParams.reluEn = reluEn; + intriParams.unitFlag = unitFlag; + + // Call AscendC Fixpipe + AscendC::Fixpipe(dst, src, intriParams); + } +}; + +template +struct CopyL0CToGm, + ScaleGranularity::NO_QUANT, ReluEnable_> { + using ArchTag = Catlass::Arch::AtlasA2; + using ElementDst = ElementDst_; + using ElementSrc = ElementAccumulator_; + using LayoutSrc = Catlass::layout::zN; + using LayoutDst = Catlass::layout::zN; + static constexpr auto quantPre = + CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::LocalTensor const &src, + LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) + { + AscendC::FixpipeParamsV220 intriParams; + + // Fixpipe layout information + intriParams.nSize = dstLayout.shape(2) * dstLayout.shape(3); + intriParams.mSize = dstLayout.shape(0) * dstLayout.shape(1); + intriParams.srcStride = srcLayout.stride(3) / srcLayout.shape(2); + intriParams.dstStride = dstLayout.stride(3) / (BYTE_PER_C0 / sizeof(ElementDst)); + + // Fixpipe auxiliary arguments + intriParams.quantPre = quantPre; + intriParams.reluEn = reluEn; + 
intriParams.unitFlag = unitFlag; + + // Call AscendC Fixpipe + AscendC::Fixpipe(dst, src, intriParams); + } +}; + +template +struct CopyL0CToGm, + ScaleGranularity::NO_QUANT, ReluEnable_> { + using ArchTag = Catlass::Arch::AtlasA2; + using ElementDst = ElementDst_; + using ElementSrc = ElementAccumulator_; + using LayoutSrc = Catlass::layout::zN; + using LayoutDst = Catlass::layout::NDC1HWC0; + static constexpr auto quantPre = + CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::LocalTensor const &src, + LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) + { + AscendC::FixpipeParamsV220 intriParams; + + intriParams.nSize = srcLayout.orgShape(1); + intriParams.mSize = srcLayout.orgShape(0); + intriParams.srcStride = srcLayout.stride(3) / srcLayout.shape(2); + intriParams.dstStride = dstLayout.shape(1) * dstLayout.shape(2); + + if constexpr (AscendC::IsSameType::value && AscendC::IsSameType::value) { + intriParams.isChannelSplit = true; + } + + intriParams.quantPre = quantPre; + intriParams.reluEn = false; + intriParams.unitFlag = unitFlag; + AscendC::Fixpipe(dst, src, intriParams); + } +}; + +///////////////////////////////////////////CopyL0CToGmTla///////////////////////////////////////////////// +// L0C copy mode +struct CopyToGM {}; +struct CopyToL1 {}; + +template +struct CopyL0CToGmTla { + static_assert(DEPENDENT_FALSE, "Unsupported copy l0c to gm, can not find the specialization."); +}; + +template +struct CopyL0CToGmTla, LayoutDst_, CoordDst_, AscendC::TPosition::GM>, + ScaleGranularity::NO_QUANT, ReluEnable_, + std::enable_if_t::value>> { + using ArchTag = Catlass::Arch::AtlasA2; + using ElementDst = ElementDst_; + using ElementSrc = typename TensorSrc_::Element; + static constexpr auto quantPre = + CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor, uint8_t unitFlag = 0) + { + static_assert( + tla::detail::isRowMajor::value && + TensorSrc::position == AscendC::TPosition::CO1 && TensorDst::position == AscendC::TPosition::GM, + "The input parameters do not match. 
TensorSrc must be L0C, while TensorDst must be GM and RowMajor"); + + AscendC::FixpipeParamsV220 intriParams; + + // Fixpipe layout information + intriParams.nSize = tla::get<1>(dstTensor.shape()); + intriParams.mSize = tla::get<0>(dstTensor.shape()); + intriParams.srcStride = tla::get<1, 1>(srcTensor.stride()) / tla::get<0, 0>(srcTensor.stride()); + intriParams.dstStride = tla::get<0>(dstTensor.stride()); + + // Fixpipe auxiliary arguments + intriParams.quantPre = quantPre; + intriParams.reluEn = reluEn; + intriParams.unitFlag = unitFlag; + + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + + // Call AscendC Fixpipe + AscendC::Fixpipe(dstTensor.data()[dstOffset], + srcTensor.data()[srcOffset], intriParams); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Tile + +#endif // CATLASS_GEMM_TILE_COPY_L0C_TO_GM_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_l1_to_bt.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_l1_to_bt.hpp new file mode 100644 index 00000000..6f9e5ac0 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_l1_to_bt.hpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_TILE_COPY_L1_TO_BT_HPP +#define CATLASS_GEMM_TILE_COPY_L1_TO_BT_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm/gemm_type.hpp" +#include "tla/tensor.hpp" + +using namespace tla; + +namespace Catlass::Gemm::Tile { + +template +struct CopyL1ToBT { + static_assert(DEPENDENT_FALSE, + "Unsupported copy l1 to biasTable buffer, can not find the specialization."); +}; + +template +struct CopyL1ToBT, + Catlass::Gemm::GemmType> { + using LayoutDst = layout::VectorLayout; + using LayoutSrc = layout::VectorLayout; + + static constexpr uint32_t ELE_NUM_PER_C2 = BYTE_PER_C2 / sizeof(ElementSrc); + + CATLASS_DEVICE + CopyL1ToBT() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::DataCopyParams intriParams; + intriParams.blockCount = 1; + intriParams.blockLen = (layoutDst.shape(0) + ELE_NUM_PER_C2 - 1) / ELE_NUM_PER_C2; + intriParams.srcStride = 0; + intriParams.dstStride = 0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Tile + +#endif // CATLASS_GEMM_TILE_COPY_L1_TO_BT_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_l1_to_l0a.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_l1_to_l0a.hpp new file mode 100644 index 00000000..5fe77b8e --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_l1_to_l0a.hpp @@ -0,0 +1,531 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_TILE_COPY_L1_TO_L0A_HPP +#define CATLASS_GEMM_TILE_COPY_L1_TO_L0A_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm/gemm_type.hpp" +#include "catlass/gemm/tile/tile_copy_tla.hpp" +#include "tla/tensor.hpp" + +namespace Catlass::Gemm::Tile { + +template +struct CopyL1ToL0A { + static_assert(DEPENDENT_FALSE, "Unsupported copy l1 to l0, can not find the specialization."); +}; + +//////////////////////////////// +/// new add gemm +template +struct CopyL1ToL0A, + Catlass::Gemm::GemmType> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + CATLASS_DEVICE + CopyL1ToL0A() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0A, + Catlass::Gemm::GemmType> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::nN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + CATLASS_DEVICE + CopyL1ToL0A() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = static_cast(CeilDiv(layoutSrc.orgShape(0))); + ; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + for (uint32_t i = 0; i < CeilDiv(layoutSrc.orgShape(0)); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0A, + Catlass::Gemm::GemmType> { + using Element = float; + using LayoutDst = layout::zZ; + using LayoutSrc = layout::nN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + CATLASS_DEVICE + CopyL1ToL0A() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = static_cast(CeilDiv(layoutSrc.orgShape(0))); + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + for (uint32_t i = 0; i < CeilDiv(layoutSrc.orgShape(0)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1) * 2], + loadDataParams); + } + } +}; + +template +struct CopyL1ToL0A, + Catlass::Gemm::GemmType> { + using Element = int8_t; + using LayoutDst = layout::zZ; + 
using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + CATLASS_DEVICE + CopyL1ToL0A() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = CeilDiv(layoutDst.orgShape(1)) - 1; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1) * 2], srcTensor[i * layoutSrc.stride(1)], + loadDataParams); + } + } +}; + +template +struct CopyL1ToL0A> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::NDC1HWC0; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint8_t RIGHT_MOVE_8 = 8; + + // Methods + + CATLASS_DEVICE + CopyL1ToL0A(uint32_t strideW = 0, uint32_t strideH = 0, uint32_t filterW = 0, uint32_t filterH = 0, + uint32_t dilationFilterW = 0, uint32_t dilationFilterH = 0) + { + loadData3Dv2Params.strideW = strideW; + loadData3Dv2Params.strideH = strideH; + loadData3Dv2Params.filterW = filterW; + loadData3Dv2Params.filterSizeW = filterW >> RIGHT_MOVE_8; + loadData3Dv2Params.filterH = filterH; + loadData3Dv2Params.filterSizeH = filterH >> RIGHT_MOVE_8; + loadData3Dv2Params.dilationFilterW = dilationFilterW; + loadData3Dv2Params.dilationFilterH = dilationFilterH; + } + + CATLASS_DEVICE + static CopyL1ToL0A MakeCopyL1ToL0A(uint32_t strideW = 0, uint32_t strideH = 0, uint32_t filterW = 0, + uint32_t filterH = 0, uint32_t dilationFilterW = 0, uint32_t dilationFilterH = 0) + { + return CopyL1ToL0A(strideW, strideH, filterW, filterH, dilationFilterW, dilationFilterH); + } + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc, uint32_t kStartPt, uint32_t mStartPt) + { + loadData3Dv2Params.kStartPt = kStartPt; + loadData3Dv2Params.mStartPt = mStartPt; + loadData3Dv2Params.kExtension = layoutDst.orgShape(1); + loadData3Dv2Params.mExtension = layoutDst.orgShape(0); + loadData3Dv2Params.channelSize = layoutSrc.orgShape(1) * layoutSrc.orgShape(2) * layoutSrc.orgShape(5); + static constexpr AscendC::IsResetLoad3dConfig CONV3D_LOAD3DV2_DEFAULT_CONFIG = {false, false}; + AscendC::LoadData(dstTensor, srcTensor, loadData3Dv2Params); + } + +private: + AscendC::LoadData3DParamsV2 loadData3Dv2Params; +}; + +////////////////////////////////////////// + +/// Partial specialization for zN in and zZ out. 
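The NDC1HWC0 variant above is the only stateful copy in this family: the convolution geometry is fixed at construction time and only the k/m start points change per call. A usage sketch with hypothetical filter parameters; CopyA stands in for the NDC1HWC0 specialization above (its full template argument list is omitted), and the tensors and layouts are assumed to come from the surrounding conv kernel.

// Hypothetical values: 3x3 filter, stride 1, dilation 1.
auto copyA = CopyA::MakeCopyL1ToL0A(/*strideW*/ 1, /*strideH*/ 1,
                                    /*filterW*/ 3, /*filterH*/ 3,
                                    /*dilationFilterW*/ 1, /*dilationFilterH*/ 1);
// Per-tile call: layouts describe the L0A tile and the L1 NDC1HWC0 block,
// kStartPt / mStartPt select where this tile starts inside the L1 image.
copyA(l0aTensor, l1Tensor, layoutL0A, layoutL1, /*kStartPt*/ 0, /*mStartPt*/ 0);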
+template +struct CopyL1ToL0A> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyL1ToL0A() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +/// Partial specialization for float, zN in and zZ out. +template +struct CopyL1ToL0A> { + using Element = float; + using LayoutDst = layout::zZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyL1ToL0A() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + constexpr uint8_t PAD_LIST[4] = {0, 0, 0, 0}; + uint16_t l1M = layoutSrc.shape(0) * layoutSrc.shape(1); + uint16_t l1K = layoutSrc.shape(2) * layoutSrc.shape(3); + uint16_t l0M = layoutDst.shape(0) * layoutDst.shape(1); + uint16_t l0K = layoutDst.shape(2) * layoutDst.shape(3); + AscendC::SetFmatrix(1, l1M, PAD_LIST, AscendC::FmatrixMode::FMATRIX_LEFT); + static constexpr AscendC::IsResetLoad3dConfig config = {false, false}; + AscendC::LoadData3DParamsV2 loadDataParams; + loadDataParams.kExtension = l0K; + loadDataParams.mExtension = l0M; + loadDataParams.channelSize = l1K; + + AscendC::LoadData(dstTensor, srcTensor, loadDataParams); + } +}; + +template +struct CopyL1ToL0A> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + CATLASS_DEVICE + CopyL1ToL0A() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +/// Partial specialization for int8_t, nZ in and zZ out. 
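Taken together, the nZ-to-zZ (transposed A) specializations in this file pick a different load instruction per element type; the summary below only restates what the surrounding code does.

// Transposed (nZ -> zZ) A-tile load, by element type:
//   generic elements (e.g. half) : LoadData2DParams with ifTranspose = true
//   float                        : LoadData3DParamsV2 with enTranspose = true,
//                                  after rounding K and M up to 16
//   int8_t                       : LoadData2dTransposeParams via LoadDataWithTranspose,
//                                  with the destination fractal stride doubled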
(Transpose A) +template +struct CopyL1ToL0A> { + using Element = int8_t; + using LayoutDst = layout::zZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyL1ToL0A() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = CeilDiv(layoutDst.orgShape(1)) - 1; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1) * 2], srcTensor[i * layoutSrc.stride(1)], + loadDataParams); + } + } +}; + +/// Partial specialization for float, nZ in and zZ out. (Transpose A) +template +struct CopyL1ToL0A> { + using Element = float; + using LayoutDst = layout::zZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyL1ToL0A() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + constexpr uint8_t PAD_LIST[4] = {0, 0, 0, 0}; + uint16_t l1M = layoutSrc.shape(0) * layoutSrc.shape(1); + uint16_t l1K = layoutSrc.shape(2) * layoutSrc.shape(3); + uint16_t l0M = layoutDst.shape(0) * layoutDst.shape(1); + uint16_t l0K = layoutDst.shape(2) * layoutDst.shape(3); + // K, M need to be 16 aligned for f32 + uint16_t l1MAlign = RoundUp(l1M); + uint16_t l1KAlign = RoundUp(l1K); + uint16_t l0MAlign = RoundUp(l0M); + uint16_t l0KAlign = RoundUp(l0K); + AscendC::SetFmatrix(1, l1KAlign, PAD_LIST, AscendC::FmatrixMode::FMATRIX_LEFT); + static constexpr AscendC::IsResetLoad3dConfig config = {false, false}; + AscendC::LoadData3DParamsV2 loadDataParams; + loadDataParams.kExtension = l0MAlign; + loadDataParams.mExtension = l0KAlign; + loadDataParams.enTranspose = true; + loadDataParams.channelSize = l1MAlign; + + AscendC::LoadData(dstTensor, srcTensor, loadDataParams); + } +}; + +///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// + +/// Partial specialization for CopyL1ToL0A, AtlasA2, zN in and zZ out. +template +struct TileCopyTla, LayoutSrc, CoordSrc, AscendC::TPosition::A1>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::A2>, + std::enable_if_t::value && + tla::detail::iszZ::value>> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + CATLASS_DEVICE + TileCopyTla() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + static_assert( + tla::detail::iszN::value && + tla::detail::iszZ::value && + TensorSrc::position == AscendC::TPosition::A1 && TensorDst::position == AscendC::TPosition::A2, + "The input parameters do not match. 
TensorSrc must be L1 and zN, while TensorDst must be L0A and zZ"); + + const uint32_t srcOuterStrideRow = tla::get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = tla::get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = tla::get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = tla::get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = tla::get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[dstOffset + i * dstOuterStrideRow], + srcTensor.data()[srcOffset + i * srcOuterStrideRow], loadDataParams); + } + } +}; + +/// Partial specialization for CopyL1ToL0A, AtlasA2, nZ in and zZ out. (Transpose A) +template +struct TileCopyTla, LayoutSrc, CoordSrc, AscendC::TPosition::A1>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::A2>, + std::enable_if_t::value && + tla::detail::iszZ::value>> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + CATLASS_DEVICE + TileCopyTla() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + static_assert( + tla::detail::isnZ::value && + tla::detail::iszZ::value && + TensorSrc::position == AscendC::TPosition::A1 && TensorDst::position == AscendC::TPosition::A2, + "The input parameters do not match. TensorSrc must be L1 and nZ, while TensorDst must be L0A and zZ"); + + const uint32_t srcOuterStrideRow = tla::get<0, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = tla::get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = tla::get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = tla::get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[dstOffset + i * dstOuterStrideRow], + srcTensor.data()[srcOffset + i * srcOuterStrideRow], loadDataParams); + } + } +}; + +/// Partial specialization for CopyL1ToL0A, AtlasA2, int8_t, nZ in and zZ out. 
(Transpose A) +template +struct TileCopyTla< + Arch::AtlasA2, tla::Tensor, LayoutSrc, CoordSrc, AscendC::TPosition::A1>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::A2>, + std::enable_if_t::value && tla::detail::iszZ::value>> { + using Element = int8_t; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + TileCopyTla() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + static_assert(std::is_same_v && + std::is_same_v && + tla::detail::isnZ::value && + tla::detail::iszZ::value && + TensorSrc::position == AscendC::TPosition::A1 && + TensorDst::position == AscendC::TPosition::A2, + "The input parameters do not match. TensorSrc must be int8_t, L1 and nZ, " + "while TensorDst must be int8_t, L0A and zZ"); + + const uint32_t srcOuterShapeRow = tla::get<0, 1>(srcTensor.shape()); + const uint32_t srcOuterStrideRow = tla::get<0, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeCol = tla::get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = tla::get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = dstOuterShapeCol - 1; + + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + + for (uint32_t i = 0; i < srcOuterShapeRow; i++) { + AscendC::LoadDataWithTranspose(dstTensor.data()[dstOffset + i * dstOuterStrideRow * 2], + srcTensor.data()[srcOffset + i * srcOuterStrideRow], loadDataParams); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Tile + +#endif // CATLASS_GEMM_TILE_COPY_L1_TO_L0A_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_l1_to_l0b.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_l1_to_l0b.hpp new file mode 100644 index 00000000..ba843a83 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_l1_to_l0b.hpp @@ -0,0 +1,602 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMM_TILE_COPY_L1_TO_L0B_HPP +#define CATLASS_GEMM_TILE_COPY_L1_TO_L0B_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm/gemm_type.hpp" +#include "catlass/gemm/tile/tile_copy_tla.hpp" +#include "tla/tensor.hpp" + +namespace Catlass::Gemm::Tile { + +template +struct CopyL1ToL0B { + static_assert(DEPENDENT_FALSE, "Unsupported copy l1 to l0, can not find the specialization."); +}; + +//////////////////////////////////////// +/// new add gemm +template +struct CopyL1ToL0B, + Catlass::Gemm::GemmType> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + CATLASS_DEVICE + CopyL1ToL0B() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutSrc.orgShape(1))); + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { // K N + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, + Catlass::Gemm::GemmType> { + using Element = float; + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + CATLASS_DEVICE + CopyL1ToL0B() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutSrc.orgShape(1))); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = static_cast(CeilDiv(layoutDst.orgShape(1))) - 1; + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { // K N + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1) * 2], srcTensor[i * layoutSrc.stride(1)], + loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, + Catlass::Gemm::GemmType> { + using Element = int8_t; + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + CATLASS_DEVICE + CopyL1ToL0B() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL / 2; + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1) * 2], + loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, + Catlass::Gemm::GemmType> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t 
ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyL1ToL0B() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; +///////////////////////////////////////////// + +//////////////////////////////////////////// +/// new add gemv +template +struct CopyL1ToL0B, + Catlass::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyL1ToL0B() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(1)); + loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(1) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(3); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(3)], srcTensor[i * layoutSrc.stride(3)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, + Catlass::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyL1ToL0B() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = layoutDst.shape(1) * layoutDst.shape(3); + loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(1) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + AscendC::LoadData(dstTensor, srcTensor, loadDataParams); + }; +}; + +template +struct CopyL1ToL0B, + Catlass::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + using Element = float; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyL1ToL0B() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, 
AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(0))); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = CeilDiv(layoutDst.orgShape(0)) - 1; + + for (uint32_t i = 0; i < CeilDiv<2 * ELE_NUM_PER_C0>(layoutDst.orgShape(1)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(3) * 2], srcTensor[i * layoutSrc.stride(3)], + loadDataParams); + } + }; +}; + +template +struct CopyL1ToL0B, + Catlass::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nZ; + using Element = int8_t; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyL1ToL0B() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(0))); + loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL / 2; + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(1)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(3)], srcTensor[i * layoutSrc.stride(3) * 2], + loadDataParams); + } + } +}; +//////////////////////////////////////////// + +/// Partial specialization for int8_t, zN in and nZ out. +template +struct CopyL1ToL0B> { + using Element = int8_t; + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyL1ToL0B() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL / 2; + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1) * 2], + loadDataParams); + } + } +}; + +/// Partial specialization for float, zN in and nZ out. 
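+/// Unlike the 2-D loads used by the other element types, the float path below issues a
+/// single LoadData3DParamsV2 transfer: K and N are first rounded up to 16 elements (see
+/// the "16 aligned for f32" note in the body) and the source is described to the fmatrix
+/// with FMATRIX_RIGHT before the load.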
+template +struct CopyL1ToL0B> { + using Element = float; + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyL1ToL0B() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + constexpr uint8_t PAD_LIST[4] = {0, 0, 0, 0}; + uint16_t l1K = layoutSrc.shape(0) * layoutSrc.shape(1); + uint16_t l1N = layoutSrc.shape(2) * layoutSrc.shape(3); + uint16_t l0K = layoutDst.shape(0) * layoutDst.shape(1); + uint16_t l0N = layoutDst.shape(2) * layoutDst.shape(3); + // K, N need to be 16 aligned for f32 + uint16_t l1KAlign = RoundUp(l1K); + uint16_t l1NAlign = RoundUp(l1N); + uint16_t l0KAlign = RoundUp(l0K); + uint16_t l0NAlign = RoundUp(l0N); + AscendC::SetFmatrix(1, l1KAlign, PAD_LIST, AscendC::FmatrixMode::FMATRIX_RIGHT); + static constexpr AscendC::IsResetLoad3dConfig config = {false, false}; + AscendC::LoadData3DParamsV2 loadDataParams; + loadDataParams.kExtension = l0NAlign; + loadDataParams.mExtension = l0KAlign; + loadDataParams.channelSize = l1NAlign; + loadDataParams.fMatrixCtrl = true; + + AscendC::LoadData(dstTensor, srcTensor, loadDataParams); + } +}; + +/// Partial specialization for zN in and nZ out. +template +struct CopyL1ToL0B> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyL1ToL0B() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +/// Partial specialization for nZ in and nZ out. 
(Transpose B) +template +struct CopyL1ToL0B> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + CopyL1ToL0B() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + if (layoutSrc.shape(3) == layoutDst.shape(3)) { + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(1) * layoutDst.shape(3)); + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + AscendC::LoadData(dstTensor, srcTensor, loadDataParams); + } else { + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], + loadDataParams); + } + } + } +}; + +///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// +/// Partial specialization for CopyL1ToL0B, AtlasA2, zN in and nZ out. +template +struct TileCopyTla, LayoutSrc, CoordSrc, AscendC::TPosition::A1>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::B2>, + std::enable_if_t::value && + tla::detail::isnZ::value>> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + CATLASS_DEVICE + TileCopyTla() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + static_assert( + tla::detail::iszN::value && + tla::detail::isnZ::value && + TensorSrc::position == AscendC::TPosition::A1 && TensorDst::position == AscendC::TPosition::B2, + "The input parameters do not match. TensorSrc must be L1 and zN, while TensorDst must be L0B and nZ"); + + const uint32_t srcOuterStrideRow = tla::get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = tla::get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = tla::get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = tla::get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = tla::get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[dstOffset + i * dstOuterStrideRow], + srcTensor.data()[srcOffset + i * srcOuterStrideRow], loadDataParams); + } + } +}; + +/// Partial specialization for CopyL1ToL0B, AtlasA2, nZ in and nZ out. 
(Transpose B) +template +struct TileCopyTla, LayoutSrc, CoordSrc, AscendC::TPosition::A1>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::B2>, + std::enable_if_t::value && + tla::detail::isnZ::value>> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + CATLASS_DEVICE + TileCopyTla() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + static_assert( + tla::detail::isnZ::value && + tla::detail::isnZ::value && + TensorSrc::position == AscendC::TPosition::A1 && TensorDst::position == AscendC::TPosition::B2, + "The input parameters do not match. TensorSrc must be L1 and nZ, while TensorDst must be L0B and nZ"); + + const uint32_t srcOuterStrideRow = tla::get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = tla::get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = tla::get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = tla::get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = tla::get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[dstOffset + i * dstOuterStrideRow], + srcTensor.data()[srcOffset + i * srcOuterStrideRow], loadDataParams); + } + } +}; + +/// Partial specialization for CopyL1ToL0B, AtlasA2, int8_t, zN in and nZ out. +template +struct TileCopyTla< + Arch::AtlasA2, tla::Tensor, LayoutSrc, CoordSrc, AscendC::TPosition::A1>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::B2>, + std::enable_if_t::value && tla::detail::isnZ::value>> { + using Element = int8_t; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + CATLASS_DEVICE + TileCopyTla() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + static_assert(std::is_same_v && + std::is_same_v && + tla::detail::iszN::value && + tla::detail::isnZ::value && + TensorSrc::position == AscendC::TPosition::A1 && + TensorDst::position == AscendC::TPosition::B2, + "The input parameters do not match. 
TensorSrc must be int8_t, L1 and zN, " + "while TensorDst must be int8_t, L0B and nZ"); + + const uint32_t srcOuterShapeCol = tla::get<1, 1>(srcTensor.shape()); + const uint32_t srcOuterStrideRow = tla::get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = tla::get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = tla::get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = tla::get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = srcOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL / 2; + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + auto dstOffset = dstTensor.layout()(dstTensor.coord()); + auto srcOffset = srcTensor.layout()(srcTensor.coord()); + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadDataWithTranspose(dstTensor.data()[dstOffset + i * dstOuterStrideRow], + srcTensor.data()[srcOffset + i * srcOuterStrideRow * 2], loadDataParams); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Tile + +#endif // CATLASS_GEMM_TILE_COPY_L1_TO_L0B_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_ub_to_gm.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_ub_to_gm.hpp new file mode 100644 index 00000000..e95843c6 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/copy_ub_to_gm.hpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_TILE_COPY_UB_TO_GM_HPP +#define CATLASS_GEMM_TILE_COPY_UB_TO_GM_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/gemm/tile/tile_copy_tla.hpp" +#include "tla/tensor.hpp" + +namespace Catlass::Gemm::Tile { + +/// Partial specialization for AtlasA2, RowMajor in and RowMajor out. +template +struct TileCopyTla< + Arch::AtlasA2, tla::Tensor, LayoutSrc, CoordSrc, AscendC::TPosition::VECCALC>, + tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::GM>, + std::enable_if_t::value && tla::detail::isRowMajor::value>> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + CATLASS_DEVICE + TileCopyTla() {}; + + template + CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + static_assert(tla::detail::isRowMajor::value && + tla::detail::isRowMajor::value && + TensorSrc::position == AscendC::TPosition::VECCALC && + TensorDst::position == AscendC::TPosition::GM, + "The input parameters do not match. 
TensorSrc must be UB and RowMajor, "
+                      "while TensorDst must be GM and RowMajor");
+
+        AscendC::DataCopyExtParams dataCopyParams(
+            tla::get<0>(dstTensor.shape()), tla::get<1>(dstTensor.shape()) * sizeof(ElementSrc),
+            (tla::get<0>(srcTensor.stride()) - tla::get<1>(srcTensor.shape())) / ELE_NUM_PER_C0,
+            (tla::get<0>(dstTensor.stride()) - tla::get<1>(dstTensor.shape())) * sizeof(ElementSrc), 0);
+        auto dstOffset = dstTensor.layout()(dstTensor.coord());
+        auto srcOffset = srcTensor.layout()(srcTensor.coord());
+        AscendC::DataCopyPad(dstTensor.data()[dstOffset], srcTensor.data()[srcOffset], dataCopyParams);
+    };
+};
+
+/// Partial specialization for AtlasA2, RowMajor in and PaddingRowMajor out.
+template
+struct TileCopyTlaExt, LayoutSrc, CoordSrc, AscendC::TPosition::VECCALC>,
+                      tla::Tensor, LayoutDst, CoordDst, AscendC::TPosition::GM>,
+                      layout::RowMajor, layout::PaddingRowMajor> {
+    static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc);
+
+    // Methods
+
+    CATLASS_DEVICE
+    TileCopyTlaExt() {};
+
+    template
+    CATLASS_DEVICE void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor)
+    {
+        static_assert(tla::detail::isRowMajor::value &&
+                          TensorSrc::position == AscendC::TPosition::VECCALC &&
+                          TensorDst::position == AscendC::TPosition::GM,
+                      "The input parameters do not match. TensorSrc must be UB and RowMajor, "
+                      "while TensorDst must be GM and PaddingRowMajor");
+
+        AscendC::DataCopyExtParams dataCopyParams(
+            tla::get<1, 1>(dstTensor.shape()), tla::get<1, 0>(dstTensor.shape()) * sizeof(ElementSrc),
+            (tla::get<0>(srcTensor.stride()) - tla::get<1>(srcTensor.shape())) / ELE_NUM_PER_C0,
+            (tla::get<1, 1>(dstTensor.stride()) - tla::get<1, 0>(dstTensor.shape())) * sizeof(ElementSrc), 0);
+        auto dstOffset = dstTensor.layout()(dstTensor.coord());
+        auto srcOffset = srcTensor.layout()(srcTensor.coord());
+        AscendC::DataCopyPad(dstTensor.data()[dstOffset], srcTensor.data()[srcOffset], dataCopyParams);
+    };
+};
+
+} // namespace Catlass::Gemm::Tile
+
+#endif // CATLASS_GEMM_TILE_COPY_UB_TO_GM_HPP
diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/tile_copy.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/tile_copy.hpp
new file mode 100644
index 00000000..08e34eaf
--- /dev/null
+++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/tile_copy.hpp
@@ -0,0 +1,252 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#ifndef CATLASS_GEMM_TILE_TILE_COPY_HPP +#define CATLASS_GEMM_TILE_TILE_COPY_HPP + +#include +#include "catlass/catlass.hpp" +#include "catlass/detail/tag_to_layout.hpp" +#include "tla/tensor.hpp" +#include "catlass/gemm/tile/copy_gm_to_l1.hpp" +#include "catlass/gemm/tile/copy_l0c_to_gm.hpp" +#include "catlass/gemm/tile/copy_l1_to_l0a.hpp" +#include "catlass/gemm/tile/copy_l1_to_l0b.hpp" +#include "catlass/gemm/tile/copy_l1_to_bt.hpp" +#include "catlass/gemm/tile/copy_gm_to_ub.hpp" +#include "catlass/gemm/tile/copy_ub_to_gm.hpp" +#include "catlass/gemm/helper.hpp" + +namespace Catlass::Gemm::Tile { + +template < + /// Tag indicating architecture + class ArchTag, + /// GemmType for A matrix operand + class AType, + /// GemmType type for B matrix operand + class BType, + /// GemmType type for C matrix operand + class CType, + /// GemmType type for Bias operand + class BiasType = void> +struct TileCopy { + using ElementA = typename AType::Element; + using ElementB = typename BType::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + using CopyGmToL1A = Gemm::Tile::CopyGmToL1; + using CopyGmToL1B = Gemm::Tile::CopyGmToL1; + using CopyL1ToL0A = Gemm::Tile::CopyL1ToL0A::L1AType>; + using CopyL1ToL0B = Gemm::Tile::CopyL1ToL0B::L1BType>; + using CopyL0CToGm = Gemm::Tile::CopyL0CToGm; + using BiasTypeSelector = helper::L1BiasTypeSelector; + using CopyGmToL1Bias = std::conditional_t< + std::is_same_v, void, + Gemm::Tile::CopyGmToL1>; + using CopyL1ToBT = std::conditional_t< + std::is_same_v, void, + Gemm::Tile::CopyL1ToBT>; +}; + +template < + /// Tag indicating architecture + class ArchTag, class ElementA_, class LayoutTagA, class ElementB_, class LayoutTagB, class ElementC_, + class LayoutTagC, class ElementBias = void, class LayoutTagBias = void, class L0CCopyMode = CopyToGM> +struct PackedTileCopyTla { + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + using LayoutTagL1A = typename helper::L1ATypeSelector>::L1AType::Layout; + using LayoutTagL1B = typename helper::L1BTypeSelector>::L1BType::Layout; + using LayoutTagL0A = layout::zZ; + using LayoutTagL0B = layout::nZ; + + using LayoutA = detail::TagToLayout_t; + using LayoutB = detail::TagToLayout_t; + using LayoutC = detail::TagToLayout_t; + + using LayoutL1A = detail::TagToLayout_t; + using LayoutL1B = detail::TagToLayout_t; + using LayoutL0A = detail::TagToLayout_t; + using LayoutL0B = detail::TagToLayout_t; + using LayoutL0C = typename detail::LayoutL0C; + + using TensorL1A = + tla::Tensor, LayoutL1A, tla::Coord, AscendC::TPosition::A1>; + using TensorL1B = + tla::Tensor, LayoutL1B, tla::Coord, AscendC::TPosition::A1>; + using TensorL0A = + tla::Tensor, LayoutL0A, tla::Coord, AscendC::TPosition::A2>; + using TensorL0B = + tla::Tensor, LayoutL0B, tla::Coord, AscendC::TPosition::B2>; + using TensorL0C = tla::Tensor, LayoutL0C, tla::Coord, + AscendC::TPosition::CO1>; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + template + using CopyGmToL1A = Gemm::Tile::TileCopyTla; + + template + using CopyGmToL1B = Gemm::Tile::TileCopyTla; + + using CopyL1ToL0A = Gemm::Tile::TileCopyTla; + using CopyL1ToL0B = Gemm::Tile::TileCopyTla; + + template + using CopyL0CToGm = Gemm::Tile::CopyL0CToGmTla; +}; + +template < + /// Tag indicating architecture + class ArchTag, class TensorA, class LayoutTagA, class TensorB, 
class LayoutTagB, class TensorC, class LayoutTagC, + class TensorBias = void, class LayoutTagBias = void, bool IS_PADDING_A = false, bool IS_PADDING_B = false> +struct PaddingPackedTileCopyTla { + static_assert(std::is_same_v || std::is_same_v, + "Unsupported layout, only can be RowMajor and ColumnMajor"); + static_assert(std::is_same_v || std::is_same_v, + "Unsupported layout, only can be RowMajor and ColumnMajor"); + using ElementA = typename TensorA::Element; + using ElementB = typename TensorB::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + using LayoutTagL1A = typename helper::L1ATypeSelector>::L1AType::Layout; + using LayoutTagL1B = typename helper::L1BTypeSelector>::L1BType::Layout; + using LayoutTagL0A = layout::zZ; + using LayoutTagL0B = layout::nZ; + + using LayoutL1A = detail::TagToLayout_t; + using LayoutL1B = detail::TagToLayout_t; + using LayoutL0A = detail::TagToLayout_t; + using LayoutL0B = detail::TagToLayout_t; + using LayoutL0C = typename detail::LayoutL0C; + + using TensorL1A = + tla::Tensor, LayoutL1A, tla::Coord, AscendC::TPosition::A1>; + using TensorL1B = + tla::Tensor, LayoutL1B, tla::Coord, AscendC::TPosition::A1>; + using TensorL0A = + tla::Tensor, LayoutL0A, tla::Coord, AscendC::TPosition::A2>; + using TensorL0B = + tla::Tensor, LayoutL0B, tla::Coord, AscendC::TPosition::B2>; + using TensorL0C = tla::Tensor, LayoutL0C, tla::Coord, + AscendC::TPosition::CO1>; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + using LayoutPaddingTagA = std::conditional_t, layout::PaddingRowMajor, + layout::PaddingColumnMajor>; + using LayoutPaddingTagB = std::conditional_t, layout::PaddingRowMajor, + layout::PaddingColumnMajor>; + + using CopyGmToL1A = + std::conditional_t, + Gemm::Tile::TileCopyTlaExt>; + using CopyGmToL1B = + std::conditional_t, + Gemm::Tile::TileCopyTlaExt>; + + using CopyL1ToL0A = Gemm::Tile::TileCopyTla; + using CopyL1ToL0B = Gemm::Tile::TileCopyTla; + using CopyL0CToGm = Gemm::Tile::CopyL0CToGmTla; +}; +/////////////////////////////////// +/// new add +template < + /// Tag indicating architecture + class ArchTag, + /// GemmType for A matrix operand + class AType, + /// GemmType type for B matrix operand + class BType, + /// GemmType type for C matrix operand + class CType, + /// GemmTpe type for Bias operand + class BiasType = void> +struct TileCopyGemm { + using ElementA = typename AType::Element; + using ElementB = typename BType::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + // change structural + using L1AType = typename Gemm::helper::L1AndL0TypeSelectorGemm::L1AType; + using L1BType = typename Gemm::helper::L1AndL0TypeSelectorGemm::L1BType; + using L0AType = typename Gemm::helper::L1AndL0TypeSelectorGemm::L0AType; + using L0BType = typename Gemm::helper::L1AndL0TypeSelectorGemm::L0BType; + + using CopyGmToL1A = Gemm::Tile::CopyGmToL1; + using CopyGmToL1B = Gemm::Tile::CopyGmToL1; + using CopyL1ToL0A = Gemm::Tile::CopyL1ToL0A; + using CopyL1ToL0B = Gemm::Tile::CopyL1ToL0B; + using CopyL0CToGm = Gemm::Tile::CopyL0CToGm; +}; +////////////////////////////// +template < + /// Tag indicating architecture + class ArchTag, + /// GemmType for A matrix operand + class AType, + /// GemmType type for B matrix operand + class BType, + /// GemmType type for C matrix operand + class CType, + /// GemmType type for Bias operand + class BiasType> +struct ConvTileCopy { + using 
ElementA = typename AType::Element;
+    using ElementB = typename BType::Element;
+    using ElementAccumulator =
+        typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator;
+
+    using CopyGmToL1A = Gemm::Tile::CopyGmToL1;
+    using CopyGmToL1B = Gemm::Tile::CopyGmToL1;
+    using CopyL1ToL0A = Gemm::Tile::CopyL1ToL0A::L1AType>;
+    using CopyL1ToL0B = Gemm::Tile::CopyL1ToL0B::L1BType>;
+    using CopyL0CToGm = Gemm::Tile::CopyL0CToGm;
+    using BiasTypeSelector = helper::L1BiasTypeSelector;
+    using CopyGmToL1Bias = std::conditional_t<
+        std::is_same_v, void,
+        Gemm::Tile::CopyGmToL1>;
+    using CopyL1ToBT = std::conditional_t<
+        std::is_same_v, void,
+        Gemm::Tile::CopyL1ToBT>;
+};
+
+// Fixpipe copy with the ReLU switch enabled
+template <
+    /// Tag indicating architecture
+    class ArchTag,
+    /// GemmType for A matrix operand
+    class AType,
+    /// GemmType type for B matrix operand
+    class BType,
+    /// GemmType type for C matrix operand
+    class CType,
+    /// GemmType type for Bias operand
+    class BiasType = void>
+struct ReluTileCopy : public TileCopy {
+    // Override CopyL0CToGm
+    using ElementAccumulator = typename TileCopy::ElementAccumulator;
+    using CopyL0CToGm = Gemm::Tile::CopyL0CToGm;
+};
+
+} // namespace Catlass::Gemm::Tile
+
+#endif // CATLASS_GEMM_TILE_TILE_COPY_HPP
diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/tile_copy_tla.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/tile_copy_tla.hpp
new file mode 100644
index 00000000..169cc205
--- /dev/null
+++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/tile_copy_tla.hpp
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#ifndef CATLASS_GEMM_TILE_TILE_COPY_TLA_HPP
+#define CATLASS_GEMM_TILE_TILE_COPY_TLA_HPP
+
+#include "catlass/catlass.hpp"
+
+namespace Catlass::Gemm::Tile {
+
+template
+struct TileCopyTla {
+    static_assert(DEPENDENT_FALSE, "Unsupported TileCopyTla, can not find the specialization.");
+};
+
+// Extended template for TileCopyTla that supports manually specifying LayoutTagSrc and LayoutTagDst.
+// Users can specialize the copy class by LayoutTagSrc and LayoutTagDst.
+template
+struct TileCopyTlaExt {
+    static_assert(DEPENDENT_FALSE, "Unsupported TileCopyTlaExt, can not find the specialization.");
+};
+
+} // namespace Catlass::Gemm::Tile
+
+#endif // CATLASS_GEMM_TILE_TILE_COPY_TLA_HPP
diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/tile_mmad.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/tile_mmad.hpp
new file mode 100644
index 00000000..a7080833
--- /dev/null
+++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm/tile/tile_mmad.hpp
@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details.
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_TILE_TILE_MMAD_HPP +#define CATLASS_GEMM_TILE_TILE_MMAD_HPP + +#include "catlass/catlass.hpp" +#include "catlass/gemm/helper.hpp" +#include "tla/tensor.hpp" +namespace Catlass::Gemm::Tile { + +/////////////////////////////////////////////////////////// + +template < + /// Tag indicating architecture + class ArchTag_, + /// GemmType for A matrix operand + class AType_, + /// GemmType type for B matrix operand + class BType_, + /// GemmType type for Bias operand + class BiasType_> +struct TileMmad { + using ElementA = typename AType_::Element; + using ElementB = typename BType_::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + // Methods + + CATLASS_DEVICE + TileMmad() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &l0CTensor, + AscendC::LocalTensor const &l0ATensor, AscendC::LocalTensor const &l0BTensor, + uint32_t m, uint32_t n, uint32_t k, bool initC = true, uint8_t unitFlag = 0) + { + AscendC::MmadParams mmadParams; + mmadParams.m = m; + mmadParams.n = n; + mmadParams.k = k; + mmadParams.unitFlag = unitFlag; + mmadParams.cmatrixInitVal = initC; + if constexpr (std::is_same_v && std::is_same_v) { + mmadParams.kDirectionAlign = true; + } + + AscendC::Mmad(l0CTensor, l0ATensor, l0BTensor, mmadParams); + + const uint32_t PIPE_M_BARRIER_THRESHOLD = 10; + if ((m / C0_NUM_PER_FRACTAL) * (n / C0_NUM_PER_FRACTAL) < PIPE_M_BARRIER_THRESHOLD) { + AscendC::PipeBarrier(); + } + } + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &l0CTensor, + AscendC::LocalTensor const &l0ATensor, AscendC::LocalTensor const &l0BTensor, + AscendC::LocalTensor const &l0BiasTensor, uint32_t m, uint32_t n, uint32_t k, + bool initC = true, uint8_t unitFlag = 0) + { + AscendC::MmadParams mmadParams; + mmadParams.m = m; + mmadParams.n = n; + mmadParams.k = k; + mmadParams.unitFlag = unitFlag; + mmadParams.cmatrixInitVal = false; + if constexpr (std::is_same_v && std::is_same_v) { + mmadParams.kDirectionAlign = true; + } + + AscendC::Mmad(l0CTensor, l0ATensor, l0BTensor, l0BiasTensor, mmadParams); + + const uint32_t PIPE_M_BARRIER_THRESHOLD = 10; + if ((m / C0_NUM_PER_FRACTAL) * (n / C0_NUM_PER_FRACTAL) < PIPE_M_BARRIER_THRESHOLD) { + AscendC::PipeBarrier(); + } + } +}; + +///////////////////////////////////////////TileMmadTla///////////////////////////////////////////////// + +template < + /// Tag indicating architecture + class ArchTag_, + /// tla::Tensor type for A matrix operand + class TensorA_, + /// tla::Tensor type for B matrix operand + class TensorB_, + /// tla::Tensor type for C matrix operand + class TensorC_, + /// tla::Tensor type for Bias operand + class TensorBias_ = void> +struct TileMmadTla { + // Methods + + CATLASS_DEVICE + TileMmadTla() {} + + template + CATLASS_DEVICE void operator()(TensorC const &l0CTensor, TensorA const &l0ATensor, TensorB const &l0BTensor, + uint32_t m, uint32_t n, uint32_t k, bool initC = true, uint8_t unitFlag = 0) + { + AscendC::MmadParams mmadParams; + mmadParams.m = m; + mmadParams.n = n; + mmadParams.k = k; + mmadParams.unitFlag = unitFlag; + mmadParams.cmatrixInitVal = initC; + + 
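+        // Issue the matrix multiply-accumulate on the cube unit; cmatrixInitVal (initC)
+        // decides whether the L0C tile is initialized or accumulated into, and small
+        // tiles (fractal count below PIPE_M_BARRIER_THRESHOLD) are followed by a
+        // PipeBarrier on the mmad pipe.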
AscendC::Mmad(l0CTensor.data(), l0ATensor.data(), l0BTensor.data(), mmadParams); + + const uint32_t PIPE_M_BARRIER_THRESHOLD = 10; + if ((m / C0_NUM_PER_FRACTAL) * (n / C0_NUM_PER_FRACTAL) < PIPE_M_BARRIER_THRESHOLD) { + AscendC::PipeBarrier(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemm::Tile + +#endif // CATLASS_GEMM_TILE_TILE_MMAD_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm_coord.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm_coord.hpp new file mode 100644 index 00000000..869b5a64 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemm_coord.hpp @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMM_COORD_HPP +#define CATLASS_GEMM_COORD_HPP + +#include "catlass/coord.hpp" + +namespace Catlass { + +/// Shape of a matrix multiply-add operation +template < + /// Rows of matrix product + uint32_t M_ = 1, + /// Columns of matrix product + uint32_t N_ = 1, + /// Inner dimension of matrix product + uint32_t K_ = 1> +struct GemmShape { + static constexpr uint32_t M = M_; + static constexpr uint32_t N = N_; + static constexpr uint32_t K = K_; + + static constexpr int64_t MN = M * N; + static constexpr int64_t MK = M * K; + static constexpr int64_t KN = N * K; + static constexpr int64_t MNK = M * N * K; + + static constexpr int64_t COUNT = MNK; + + /// Returns a Coord object + CATLASS_HOST_DEVICE + static Coord<3> ToCoord() + { + return MakeCoord(M, N, K); + } + + CATLASS_HOST_DEVICE + static Coord<2> ToCoordMN() + { + return MakeCoord(M, N); + } + + CATLASS_HOST_DEVICE + static Coord<2> ToCoordMK() + { + return MakeCoord(M, K); + } + + CATLASS_HOST_DEVICE + static Coord<2> ToCoordKN() + { + return MakeCoord(K, N); + } +}; + +/// GemmCoord is a structure derived from Coord<3> that specifies a location within the +/// coordinate space of a Gemm problem. 
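+/// A minimal usage sketch (illustrative; m, n, k and splitK are placeholder values):
+///   GemmCoord problem(m, n, k);             // M/N/K extents of the problem
+///   auto cTileCoord = problem.GetCoordMN(); // drops the K component
+///   problem.k() = splitK;                   // components are writable via the references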
+struct GemmCoord : public Coord<3, uint32_t> { + /// Integer-valued index + using Index = uint32_t; + + /// Base type is a Coord of rank=3 + using Base = Coord<3, Index>; + + /// Gemm M dimension - rows of the output C matrix + static constexpr int M_INDEX = 0; + + /// Gemm N dimension - columns of the output C matrix + static constexpr int N_INDEX = 1; + + /// Gemm K dimension - inner dimension of the Gemm problem + static constexpr int K_INDEX = 2; + + /// Default ctor + CATLASS_HOST_DEVICE + GemmCoord() {} + + /// Constructs from Coord<3> and a batch + CATLASS_HOST_DEVICE + GemmCoord(Coord<3, Index> const &coord) : Base(coord) {} + + /// Helper to construct from a K, N, M, batch variables + CATLASS_HOST_DEVICE + GemmCoord(Index m, Index n, Index k) : Base(MakeCoord(m, n, k)) {} + + /// Returns the Gemm M coordinate + CATLASS_HOST_DEVICE + Index const &m() const + { + return this->At(M_INDEX); + } + + /// Returns reference to the Gemm M coordinate + CATLASS_HOST_DEVICE + Index &m() + { + return this->At(M_INDEX); + } + + /// Returns the Gemm N coordinate + CATLASS_HOST_DEVICE + Index const &n() const + { + return this->At(N_INDEX); + } + + /// Returns reference to the Gemm N coordinate + CATLASS_HOST_DEVICE + Index &n() + { + return this->At(N_INDEX); + } + + /// Returns the Gemm K coordinate + CATLASS_HOST_DEVICE + Index const &k() const + { + return this->At(K_INDEX); + } + + /// Returns reference to the Gemm K coordinate + CATLASS_HOST_DEVICE + Index &k() + { + return this->At(K_INDEX); + } + + CATLASS_HOST_DEVICE + auto GetCoordMN() const + { + return this->GetCoordByAxis(); + } + + CATLASS_HOST_DEVICE + auto GetCoordMK() const + { + return this->GetCoordByAxis(); + } + + CATLASS_HOST_DEVICE + auto GetCoordKN() const + { + return this->GetCoordByAxis(); + } +}; + +} // namespace Catlass + +#endif // CATLASS_GEMM_COORD_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/block/block_gemv.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/block/block_gemv.hpp new file mode 100644 index 00000000..b5a27112 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/block/block_gemv.hpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMV_BLOCK_BLOCK_GEMV_HPP +#define CATLASS_GEMV_BLOCK_BLOCK_GEMV_HPP + +#include "catlass/catlass.hpp" +namespace Catlass::Gemv::Block { + +template +struct BlockGemv { + static_assert(DEPENDENT_FALSE, "BlockGemv is not implemented for this DispatchPolicy"); +}; +} // namespace Catlass::Gemv::Block + +#include "catlass/gemv/block/block_gemv_aiv.hpp" +#include "catlass/gemv/block/block_gemv_aic.hpp" + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/block/block_gemv_aic.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/block/block_gemv_aic.hpp new file mode 100644 index 00000000..fc7d4b5b --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/block/block_gemv_aic.hpp @@ -0,0 +1,345 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMV_BLOCK_BLOCK_AIC_HPP +#define CATLASS_GEMV_BLOCK_BLOCK_AIC_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/tile/tile_copy.hpp" +#include "catlass/gemm/tile/tile_mmad.hpp" +#include "catlass/gemv_coord.hpp" + +#include "catlass/gemv/helper.hpp" + +namespace Catlass::Gemv::Block { + +template +struct BlockGemv, L1TileShape_, L0TileShape_, AType_, + XType_, YType_, BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = Gemm::MmadAtlasA2Preload; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementX = typename XType_::Element; + using LayoutX = typename XType_::Layout; + using ElementY = typename YType_::Element; + using LayoutY = typename YType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutXInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutAInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutXInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutAInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutYInL0 = layout::zN; + + using L1AAlignHelper = Gemv::helper::L1AlignHelper; + using L1XAlignHelper = Gemv::helper::L1AlignHelper; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t L1A_SIZE = 16 * L1TileShape::N * sizeof(ElementX); + static constexpr uint32_t L1B_SIZE = L1TileShape::M * 
L1TileShape::N * sizeof(ElementA); + static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; + static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0C_TILE_SIZE = L0TileShape::M * L0TileShape::N; + static constexpr uint32_t L0C_TILE_NUM = L0C_SIZE / L0C_TILE_SIZE / sizeof(ElementAccumulator); + static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; + static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES; + + // Check L1TileShape + static_assert((L1A_SIZE * STAGES + L1B_SIZE * STAGES) <= ArchTag::L1_SIZE, "L1TileShape exceeding the L1 space!"); + + static constexpr uint32_t L0A_TILE_SIZE = L1XAlignHelper::M_ALIGNED * L0TileShape::N * sizeof(ElementX); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::M * L0TileShape::N * sizeof(ElementA); + static_assert((L0A_TILE_SIZE * STAGES) <= L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert((L0B_TILE_SIZE * STAGES) <= L0B_SIZE, "L0TileShape exceeding the L0B space!"); + + /// Construct + CATLASS_DEVICE + BlockGemv(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_SIZE * STAGES; + // Init buffers + for (uint32_t i = 0; i < STAGES; i++) { + // Assign L1/L0A/L0B space for each stages + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_SIZE * i); + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_PINGPONG_BUF_SIZE * i); + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_PINGPONG_BUF_SIZE * i); + + l1AEventList[i] = i; + l1BEventList[i] = i + STAGES; + l0AEventList[i] = i; + l0BEventList[i] = i + STAGES; + + // The event id that needs to be set before the loop + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + AscendC::SetFlag(l0AEventList[i]); + AscendC::SetFlag(l0BEventList[i]); + } + l0CTensor = resource.l0CBuf.template GetBufferByByte(0); + } + + /// Destructor + CATLASS_DEVICE + ~BlockGemv() + { + for (uint32_t i = 0; i < STAGES; i++) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + AscendC::WaitFlag(l0AEventList[i]); + AscendC::WaitFlag(l0BEventList[i]); + } + } + + /// Perform a block-scoped vector-matrix multiply-accumulate + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmBlockX, LayoutX const &layoutX, + AscendC::GlobalTensor const &gmBlockA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmBlockY, LayoutY const &layoutY, + AscendC::GlobalTensor const &gmNextBlockX, + AscendC::GlobalTensor const &gmNextBlockA, GemvCoord const &actualShape, + GemvCoord const &actualShapeNext, bool isFirstBlock, bool hasNextBlock, uint32_t singleIdx) + { + auto layoutXInL1 = LayoutXInL1::template MakeLayout(L1XAlignHelper::M_ALIGNED, L1TileShape::N); + auto layoutAInL1 = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::N); + auto layoutInL0C = LayoutYInL0::MakeLayoutInL0C(MatrixCoord(L1XAlignHelper::M_ALIGNED, actualShape.m())); + + uint32_t nTileCount = CeilDiv(actualShape.n()); + uint32_t nTileCountNext = CeilDiv(actualShapeNext.n()); + + // Optimize points:ShuffleK + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K_) { + startTileIdx = AscendC::GetBlockIdx(); + } + uint32_t firstTileIdx = startTileIdx % nTileCount; + uint32_t lastTileIdx = (startTileIdx + nTileCount - 1) % nTileCount; + 
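+            // ShuffleK staggering: each AI Core starts its reduction loop at its own
+            // block index, so concurrent cores fetch different tiles of A and x from GM
+            // in a given iteration; the modulo wrap still visits every tile exactly once.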
uint32_t firstTileIdxNext = startTileIdx % nTileCountNext; + + uint32_t nActual = + (firstTileIdx < nTileCount - 1) ? L1TileShape::N : (actualShape.n() - firstTileIdx * L1TileShape::N); + uint32_t nRound = RoundUp(nActual); + + // main loop + for (uint32_t nLoopIdx = 0; nLoopIdx < nTileCount; nLoopIdx++) { + uint32_t shuffleKIdx = (startTileIdx + nLoopIdx) % nTileCount; + if (shuffleKIdx == firstTileIdx && isFirstBlock) { + MatrixCoord gmTileAOffset{0, shuffleKIdx * L1TileShape::N}; + uint32_t gmTilexOffset{shuffleKIdx * L1TileShape::N}; + + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTilex = gmBlockX[gmTilexOffset]; + + // load first vector x tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto layoutTilex = layoutX.GetTileLayout(MakeCoord(nRound)); + copyGmToL1A(l1ATensorList[l1ListId], gmTilex, layoutXInL1, layoutTilex); + AscendC::SetFlag(l1AEventList[l1ListId]); + + // load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), nRound)); + copyGmToL1B(l1BTensorList[l1ListId], gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1BEventList[l1ListId]); + } + + uint32_t l1ListIdNext = (l1ListId + 1) % STAGES; + uint32_t nActualNext{0}; + uint32_t nRoundNext{0}; + + // preload next tile from GM to L1 + if (shuffleKIdx != lastTileIdx) { + uint32_t shuffleKIdxNext = (startTileIdx + nLoopIdx + 1) % nTileCount; + nActualNext = (shuffleKIdxNext < nTileCount - 1) ? L1TileShape::N + : (actualShape.n() - shuffleKIdxNext * L1TileShape::N); + nRoundNext = RoundUp(nActualNext); + + // Get L1 tensor + auto l1ATensor = l1ATensorList[l1ListIdNext]; + auto l1BTensor = l1BTensorList[l1ListIdNext]; + + // Get GM tile + MatrixCoord gmTileAOffset{0, shuffleKIdxNext * L1TileShape::N}; + uint32_t gmTilexOffset{shuffleKIdxNext * L1TileShape::N}; + + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTilex = gmBlockX[gmTilexOffset]; + + // load vector x tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + auto layoutTilex = layoutX.GetTileLayout(MakeCoord(nRoundNext)); + + copyGmToL1A(l1ATensor, gmTilex, layoutXInL1, layoutTilex); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + + // load Matrix A tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), nRoundNext)); + + copyGmToL1B(l1BTensor, gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + if (shuffleKIdx == lastTileIdx && hasNextBlock) { + // Get L1 tensor + auto l1ATensor = l1ATensorList[l1ListIdNext]; + auto l1BTensor = l1BTensorList[l1ListIdNext]; + + // Get GM tensor for next stage + nActualNext = (firstTileIdxNext < nTileCountNext - 1) + ? 
L1TileShape::N + : (actualShapeNext.n() - firstTileIdxNext * L1TileShape::N); + nRoundNext = RoundUp(nActualNext); + + // Get GM tile + MatrixCoord gmTileAOffset{0, firstTileIdxNext * L1TileShape::N}; + uint32_t gmTilexOffset{firstTileIdxNext * L1TileShape::N}; + + auto gmTileA = gmNextBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTilex = gmNextBlockX[gmTilexOffset]; + + // load vector x tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListIdNext]); + + auto layoutTilex = layoutX.GetTileLayout(MakeCoord(nRoundNext)); + + copyGmToL1A(l1ATensor, gmTilex, layoutXInL1, layoutTilex); + AscendC::SetFlag(l1AEventList[l1ListIdNext]); + + // load Matrix A tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListIdNext]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShapeNext.m(), nRoundNext)); + + copyGmToL1B(l1BTensor, gmTileA, layoutAInL1, layoutTileA); + AscendC::SetFlag(l1BEventList[l1ListIdNext]); + } + + // get L1 Tensor for current stage + auto l1ATensor = l1ATensorList[l1ListId]; + auto l1BTensor = l1BTensorList[l1ListId]; + + AscendC::WaitFlag(l1AEventList[l1ListId]); + AscendC::WaitFlag(l1BEventList[l1ListId]); + + uint32_t nRound = RoundUp(nActual); + uint32_t nPartLoop = CeilDiv(nActual); + + for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; nPartIdx++) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? L0TileShape::N : (nActual - nPartIdx * L0TileShape::N); + + // Locate the current tile on L0A + auto l0ATile = l0ATensorList[l0AListId]; + LayoutXInL0 layoutxInL0 = + LayoutXInL0::template MakeLayout(L1XAlignHelper::M_ALIGNED, nPartActual); + + MatrixCoord l1xOffset{0, nPartIdx * L0TileShape::N}; + auto l1ATile = l1ATensor[layoutXInL1.GetOffset(l1xOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + // Load current tile from L1 to L0A + copyL1ToL0A(l0ATile, l1ATile, layoutxInL0, layoutXInL1); + AscendC::SetFlag(l0AEventList[l0AListId]); + + // Locate the current tile on L0B + auto l0BTile = l0BTensorList[l0BListId]; + LayoutAInL0 layoutAInL0 = LayoutAInL0::template MakeLayout(L0TileShape::M, nPartActual); + + MatrixCoord l1AOffset{0, nPartIdx * L0TileShape::N}; + auto l1BTile = l1BTensor[layoutAInL1.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0BEventList[l0BListId]); + // Load current tile from L1 to L0B + copyL1ToL0B(l0BTile, l1BTile, layoutAInL0, layoutAInL1); + AscendC::SetFlag(l0BEventList[l0BListId]); + + auto l0CTile = l0CTensor[(singleIdx % L0C_TILE_NUM) * L0C_TILE_SIZE]; + + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = ((nLoopIdx == 0) && (nPartIdx == 0)); + + AscendC::WaitFlag(l0BEventList[l0BListId]); + AscendC::WaitFlag(l0AEventList[l0AListId]); + tileMmad(l0CTile, l0ATile, l0BTile, L1XAlignHelper::M_ALIGNED, L0TileShape::M, nPartActual, initC); + AscendC::SetFlag(l0AEventList[l0AListId]); + AscendC::SetFlag(l0BEventList[l0BListId]); + + l0AListId = (l0AListId + 1) % STAGES; + l0BListId = (l0BListId + 1) % STAGES; + } + + AscendC::SetFlag(l1AEventList[l1ListId]); + AscendC::SetFlag(l1BEventList[l1ListId]); + + l1ListId = l1ListIdNext; + + nActual = nActualNext; + } + + auto l0CTile = l0CTensor[(singleIdx % L0C_TILE_NUM) * L0C_TILE_SIZE]; + + // copy block out + LayoutY layoutBlock = layoutY.GetTileLayout(MakeCoord(uint32_t(1), actualShape.m())); + + AscendC::SetFlag((int32_t)(singleIdx % L0C_TILE_NUM)); + AscendC::WaitFlag((int32_t)(singleIdx % L0C_TILE_NUM)); + + copyL0CToGm(gmBlockY, l0CTile, layoutBlock, layoutInL0C); + } + +protected: + // Multi-stage tensors list + 
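+    // --- Editor's note (illustrative, not part of the original patch): the lists below are
+    // STAGES deep and form a ping-pong scheme: while stage l1ListId feeds the L0A/L0B loads
+    // and the MMAD pipeline, stage (l1ListId + 1) % STAGES is refilled from GM, and the
+    // per-stage event IDs serialize reuse of each buffer. A minimal model of the index
+    // rotation, with hypothetical helper names:
+#if 0
+#include <cstdint>
+#include <cassert>
+
+uint32_t NextStage(uint32_t listId, uint32_t stages)
+{
+    return (listId + 1) % stages;   // same rotation as l1ListId / l0AListId / l0BListId
+}
+
+void CheckPingPong()
+{
+    const uint32_t kStages = 2;     // STAGES for a double-buffered dispatch policy
+    uint32_t id = 0;
+    for (uint32_t k = 0; k < 6; ++k) {
+        assert(id == k % kStages);  // tile k lands in buffer k % STAGES
+        id = NextStage(id, kStages);
+    }
+}
+#endif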
AscendC::LocalTensor l1ATensorList[STAGES]; + AscendC::LocalTensor l1BTensorList[STAGES]; + AscendC::LocalTensor l0ATensorList[STAGES]; + AscendC::LocalTensor l0BTensorList[STAGES]; + AscendC::LocalTensor l0CTensor; + + // Multi-stage event id list + int32_t l1AEventList[STAGES]; + int32_t l1BEventList[STAGES]; + int32_t l0AEventList[STAGES]; + int32_t l0BEventList[STAGES]; + + // The id of current stage + uint32_t l1ListId{0}; + + uint32_t l0AListId{0}; + uint32_t l0BListId{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; +}; + +} // namespace Catlass::Gemv::Block + +#endif // CATLASS_GEMV_BLOCK_BLOCK_AIC_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/block/block_gemv_aiv.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/block/block_gemv_aiv.hpp new file mode 100644 index 00000000..4694b16b --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/block/block_gemv_aiv.hpp @@ -0,0 +1,207 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMV_BLOCK_BLOCK_GEMV_AIV_HPP +#define CATLASS_GEMV_BLOCK_BLOCK_GEMV_AIV_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/gemv_coord.hpp" +#include "catlass/gemv/helper.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/alignment.hpp" +#include "catlass/gemm/dispatch_policy.hpp" + +namespace Catlass::Gemv::Block { + +template +struct BlockGemv { +public: + // Type Aliases + using DispatchPolicy = Gemm::GemvAtlasA2; + using ArchTag = typename DispatchPolicy::ArchTag; + using UBTileShape = UBTileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementX = typename XType_::Element; + using LayoutX = typename XType_::Layout; + using ElementY = typename YType_::Element; + using LayoutY = typename YType_::Layout; + using TileVmad = TileVmad_; + using TileVmuls = TileVmuls_; + using VecCopyGmToUb = typename TileCopy_::VecCopyGmToUb; + using VecCopyUbToGm = typename TileCopy_::VecCopyUbToGm; + using MatrixCopyGmToUb = typename TileCopy_::MatrixCopyGmToUb; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + using UBAlignHelper = Gemv::helper::UBAlignHelper; + using TensorCoord = layout::VectorLayout::TensorCoord; + static constexpr uint32_t STAGES = DispatchPolicy::STAGES; + static constexpr uint32_t Abuf_SIZE_ = 128 * 1024; + static constexpr uint32_t Xbuf_SIZE_ = 16 * 1024; + static constexpr uint32_t Ybuf_SIZE_ = 16 * 1024; + static constexpr uint32_t workspace_SIZE_ = 32 * 1024; + + CATLASS_DEVICE + BlockGemv() {} + + /// Construct + CATLASS_DEVICE + BlockGemv(Arch::Resource &resource, uint32_t UBufAddrStart = 0) + { + uint32_t UbAOffset = UBufAddrStart; + uint32_t 
UbXOffset = UBufAddrStart + Abuf_SIZE_; + uint32_t UbYOffset = UBufAddrStart + Abuf_SIZE_ + Xbuf_SIZE_; + uint32_t UbWOffset = UBufAddrStart + Abuf_SIZE_ + Xbuf_SIZE_ + Ybuf_SIZE_; + // Init buffers + for (uint32_t i = 0; i < STAGES; i++) { + // Assign L1/L0A/L0B space for each stages + UbATensorList[i] = resource.ubBuf.template GetBufferByByte(UbAOffset + i * (Abuf_SIZE_ / 2)); + UbXTensorList[i] = resource.ubBuf.template GetBufferByByte(UbXOffset + i * (Xbuf_SIZE_ / 2)); + UbYTensorList[i] = resource.ubBuf.template GetBufferByByte(UbYOffset + i * (Ybuf_SIZE_ / 2)); + UbWTensorList[i] = + resource.ubBuf.template GetBufferByByte(UbWOffset + i * (workspace_SIZE_ / 2)); + + // Assign event ID for each stages + UbInAEventList[i] = i; + UbInXEventList[i] = i + STAGES; + UbOutEventList[i] = i; + + // The event id that needs to be set before the loop + AscendC::SetFlag(UbInAEventList[i]); + AscendC::SetFlag(UbInXEventList[i]); + AscendC::SetFlag(UbOutEventList[i]); + } + } + + /// Destructor + CATLASS_DEVICE + ~BlockGemv() + { + for (uint32_t i = 0; i < STAGES; i++) { + AscendC::WaitFlag(UbInAEventList[i]); + AscendC::WaitFlag(UbInXEventList[i]); + AscendC::WaitFlag(UbOutEventList[i]); + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmX, LayoutX const &layoutX, + AscendC::GlobalTensor const &gmY, LayoutY const &layoutY, + AscendC::GlobalTensor const &gmZ, GemvCoord const &actualShape, float alpha, float beta) + { + AscendC::WaitFlag((event_t)(UbOutEventList[UbOutListId])); + vecCopyGmToUb(UbYTensorList[UbOutListId], gmY, actualShape.m()); + AscendC::SetFlag((event_t)(UbOutEventList[UbOutListId])); + AscendC::WaitFlag((event_t)(UbOutEventList[UbOutListId])); + tileVmuls(UbYTensorList[UbOutListId], UbYTensorList[UbOutListId], (ElementY)beta, actualShape.m()); + AscendC::SetFlag((event_t)(UbOutEventList[UbOutListId])); + AscendC::WaitFlag((event_t)(UbOutEventList[UbOutListId])); + + TileMRound = RoundUp(UBTileShape::M, UBAlignHelper::ALIGN); + TileNRound = RoundUp(UBTileShape::N, UBAlignHelper::ALIGN); + strideA = layoutA.stride(1) * TileNRound; + m_actual = (actualShape.m() < TileMRound) ? actualShape.m() : TileMRound; + n_actual = (actualShape.n() < TileNRound) ? actualShape.n() : TileNRound; + AscendC::WaitFlag((event_t)(UbInXEventList[UbInListId])); + vecCopyGmToUb(UbXTensorList[UbInListId], gmX, n_actual); + AscendC::SetFlag((event_t)(UbInXEventList[UbInListId])); + + AscendC::WaitFlag((event_t)(UbInAEventList[UbInListId])); + auto layoutAInUb = layoutA.GetTileLayout(MakeCoord(TileMRound, TileNRound)); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(m_actual, n_actual)); + matrixCopyGmToUb(UbATensorList[UbInListId], gmA, layoutAInUb, layoutTileA); + AscendC::SetFlag((event_t)(UbInAEventList[UbInListId])); + // main loop + uint32_t Nloop = CeilDiv(actualShape.n(), TileNRound); + for (uint32_t LoopIdx = 0; LoopIdx < Nloop; LoopIdx++) { + m_actual = (actualShape.m() < TileMRound) ? actualShape.m() : TileMRound; + n_actual = (LoopIdx == Nloop - 1) ? (actualShape.n() - LoopIdx * TileNRound) : TileNRound; + y_actual = m_actual; + x_actual = n_actual; + + uint32_t UbInListIdNext = (UbInListId + 1 < STAGES) ? (UbInListId + 1) : 0; + if (LoopIdx < Nloop - 1) { + uint32_t LoopIdxNext = LoopIdx + 1; + uint32_t m_actual_next = m_actual; + uint32_t n_actual_next = + (LoopIdxNext == Nloop - 1) ? 
(actualShape.n() - LoopIdxNext * TileNRound) : TileNRound; + uint32_t y_actual_next = m_actual_next; + uint32_t x_actual_next = n_actual_next; + // Get L1 tensor for next stage + auto matrixTensor = UbATensorList[UbInListIdNext]; + auto vecTensor = UbXTensorList[UbInListIdNext]; + + AscendC::WaitFlag((event_t)(UbInXEventList[UbInListIdNext])); + vecCopyGmToUb(vecTensor, gmX[LoopIdxNext * TileNRound], x_actual_next); + AscendC::SetFlag((event_t)(UbInXEventList[UbInListIdNext])); + AscendC::WaitFlag((event_t)(UbInAEventList[UbInListIdNext])); + auto layoutAInUb = layoutA.GetTileLayout(MakeCoord(TileMRound, TileNRound)); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(m_actual_next, n_actual_next)); + matrixCopyGmToUb(matrixTensor, gmA[LoopIdxNext * strideA], layoutAInUb, layoutTileA); + AscendC::SetFlag((event_t)(UbInAEventList[UbInListIdNext])); + } + AscendC::WaitFlag((event_t)(UbInXEventList[UbInListId])); + tileVmuls(UbXTensorList[UbInListId], UbXTensorList[UbInListId], (ElementA)alpha, x_actual); + AscendC::PipeBarrier(); + + AscendC::WaitFlag((event_t)(UbInAEventList[UbInListId])); + auto layoutComputeInUb = layoutA.GetTileLayout(MakeCoord(TileMRound, TileNRound)); + auto layoutTileCompute = layoutA.GetTileLayout(MakeCoord(m_actual, n_actual)); + tileVmad(UbYTensorList[UbOutListId], UbXTensorList[UbInListId], UbATensorList[UbInListId], + UbWTensorList[UbInListId], layoutComputeInUb, layoutTileCompute); + AscendC::SetFlag((event_t)(UbInAEventList[UbInListId])); + AscendC::SetFlag((event_t)(UbInXEventList[UbInListId])); + UbInListId = UbInListIdNext; + } + AscendC::SetFlag((event_t)(UbOutEventList[UbOutListId])); + AscendC::WaitFlag((event_t)(UbOutEventList[UbOutListId])); + auto layoutDstY = layoutY.GetTileLayout(TensorCoord(y_actual)); + auto layoutComputeInUb = layoutY.GetTileLayout(TensorCoord(y_actual)); + vecCopyUbToGm(gmZ, UbYTensorList[UbOutListId], layoutDstY, layoutComputeInUb); + AscendC::SetFlag((event_t)(UbOutEventList[UbOutListId])); + UbOutListId = (UbOutListId + 1 < STAGES) ? (UbOutListId + 1) : 0; + } + +protected: + // Multi-stage tensors list + AscendC::LocalTensor UbATensorList[STAGES]; + AscendC::LocalTensor UbXTensorList[STAGES]; + AscendC::LocalTensor UbYTensorList[STAGES]; + AscendC::LocalTensor UbWTensorList[STAGES]; + + // Multi-stage event id list + int32_t UbInAEventList[STAGES]; + int32_t UbInXEventList[STAGES]; + int32_t UbOutEventList[STAGES]; + + // The id of current stage + uint32_t UbOutListId{0}; + uint32_t UbInListId{0}; + + uint32_t m_actual, n_actual, x_actual, y_actual; + uint32_t TileMRound, TileNRound; + uint32_t strideA; + + TileVmad tileVmad; + TileVmuls tileVmuls; + MatrixCopyGmToUb matrixCopyGmToUb; + VecCopyGmToUb vecCopyGmToUb; + VecCopyUbToGm vecCopyUbToGm; +}; + +} // namespace Catlass::Gemv::Block + +#endif // CATLASS_GEMV_BLOCK_BLOCK_GEMV_AIV_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/device/device_gemv.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/device/device_gemv.hpp new file mode 100644 index 00000000..54e07282 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/device/device_gemv.hpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMV_DEVICE_GEMV_UNIVERSAL_ADAPTER_HPP +#define CATLASS_GEMV_DEVICE_GEMV_UNIVERSAL_ADAPTER_HPP + +#include +#include "catlass/catlass.hpp" +#include "catlass/status.hpp" +#include "catlass/gemv/device/kernel_adapter.hpp" + +#if defined(ENABLE_ASCENDC_DUMP) +#include "catlass/debug.hpp" +#endif + +namespace Catlass::Gemv::Device { + +template +class DeviceGemv +{ +public: + /// Argument structure: User API + using Arguments = typename GemvKernel::Arguments; + /// Argument structure: Kernel API + using Params = typename GemvKernel::Params; + +private: + /// kernel API parameters object + Params params_; + +public: + DeviceGemv() {} + ~DeviceGemv() {} + + /// Access the Params structure + Params const ¶ms() const + { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status CanImplement(Arguments const &args) + { + if (GemvKernel::CanImplement(args)) { + return Status::kSuccess; + } else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t GetWorkspaceSize(Arguments const &args) + { + size_t workspace_bytes = 0; + workspace_bytes += GemvKernel::GetWorkspaceSize(args); + return workspace_bytes; + } + + /// Initializes GEMV state from arguments + Status Initialize(Arguments const &args, uint8_t *workspace = nullptr, aclrtStream stream = nullptr) + { + // Initialize the Params structure + params_ = GemvKernel::ToUnderlyingArguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+ /// Supplied params struct must be construct by calling matmul Kernel::to_underling arguments + inline Status Run(aclrtStream stream, uint32_t blockDim, uint64_t fftsAddr) + { +#if defined(ENABLE_ASCENDC_DUMP) + uint8_t *ptrDump{nullptr}; + aclCheck(aclrtMalloc(reinterpret_cast(&ptrDump), ALL_DUMPSIZE, ACL_MEM_MALLOC_HUGE_FIRST)); + if (fftsAddr == 0) { + Catlass::KernelAdapter<<>>(params_, ptrDump); + } else { + Catlass::KernelAdapter<<>>(params_, fftsAddr, ptrDump); + } + aclCheck(aclrtSynchronizeStream(stream)); + Adx::AdumpPrintWorkSpace(ptrDump, ALL_DUMPSIZE, stream, "device_gemm"); + aclCheck(aclrtFree(ptrDump)); +#else + if (fftsAddr == 0) { + Catlass::KernelAdapter<<>>(params_); + } else { + Catlass::KernelAdapter<<>>(params_, fftsAddr); + } +#endif + return Status::kSuccess; + } + + /// Runs the kernel using initialized state + inline Status operator()(aclrtStream stream, uint32_t blockDim) + { + return Run(stream, blockDim, 0); + } + + inline Status operator()(aclrtStream stream, uint32_t blockDim, uint64_t fftsAddr) + { + return Run(stream, blockDim, fftsAddr); + } +}; +/////////////////////////////////////////////////////////////////////////////////// + +} // namespace Catlass::Gemv::Device +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/device/kernel_adapter.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/device/kernel_adapter.hpp new file mode 100644 index 00000000..7c414ed3 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/device/kernel_adapter.hpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef CATLASS_GEMV_DEVICE_KERNEL_ADAPTER_HPP +#define CATLASS_GEMV_DEVICE_KERNEL_ADAPTER_HPP + +#include "catlass/catlass.hpp" + +#if defined(ENABLE_ASCENDC_DUMP) +#include "catlass/debug.hpp" +#endif + +namespace Catlass { +/// Generic Catlass kernel template +template +CATLASS_GLOBAL void KernelAdapter(typename Operator::Params params, GM_ADDR ptrDump = nullptr) +{ + Operator op; +#if defined(ENABLE_ASCENDC_DUMP) + AscendC::InitDump(false, ptrDump, ALL_DUMPSIZE); +#endif + op(params); +} + +template +CATLASS_GLOBAL void KernelAdapter(typename Operator::Params params, uint64_t fftsAddr, GM_ADDR ptrDump = nullptr) +{ + AscendC::SetSyncBaseAddr(fftsAddr); + Operator op; +#if defined(ENABLE_ASCENDC_DUMP) + AscendC::InitDump(false, ptrDump, ALL_DUMPSIZE); +#endif + op(params); +} +} // namespace Catlass + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/helper.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/helper.hpp new file mode 100644 index 00000000..20cea570 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/helper.hpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMV_HELPER_HPP +#define CATLASS_GEMV_HELPER_HPP + +#include "catlass/catlass.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm/gemm_type.hpp" + +namespace Catlass::Gemv::helper { + +template +struct UBAlignHelper { + static constexpr uint32_t ALIGN = BYTE_PER_BLK / sizeof(Element); +}; + +template +struct AtomicAddSelector { + static_assert(DEPENDENT_FALSE, "Unsupported layout selector, can not find the specialization."); +}; + +template +struct AtomicAddSelector> { + static constexpr bool value = false; +}; + +template +struct AtomicAddSelector> { + static constexpr bool value = true; +}; + +template +struct L1AlignHelper { + static_assert(DEPENDENT_FALSE, "Unsupported align helper, can not find the specialization."); +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + static constexpr uint32_t getNAligned() + { + if constexpr (std::is_same::value) { + return ELE_NUM_PER_C0 / sizeof(Element); + } else { + return C0_NUM_PER_FRACTAL; + } + } + + static constexpr uint32_t getMAligned() + { + if constexpr (std::is_same::value) { + return ELE_NUM_PER_C0 / sizeof(Element); + } else { + return C0_NUM_PER_FRACTAL; + } + } + + static constexpr uint32_t N_ALIGNED = getNAligned(); + static constexpr uint32_t M_ALIGNED = getMAligned(); +}; + +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; +}; + +//////////////////////////////// +// new add gemvaic selector +template +struct L1AndL0TypeSelectorGemv { + static_assert(DEPENDENT_FALSE, "Unsupported layout selector, can not find the specialization."); + static_assert(DEPENDENT_FALSE, "Unsupported layout selector, can not find the specialization."); +}; + +template +struct L1AndL0TypeSelectorGemv, + Gemm::GemmType> { + using L1AType = Gemm::GemmType; + using L1BType = Gemm::GemmType; + using L0AType = Gemm::GemmType; + using L0BType = Gemm::GemmType; +}; + +template +struct L1AndL0TypeSelectorGemv, + Gemm::GemmType> { + using L1AType = Gemm::GemmType; + using L1BType = Gemm::GemmType; + using L0AType = Gemm::GemmType; + using L0BType = Gemm::GemmType; +}; + +template <> +struct L1AndL0TypeSelectorGemv, + Gemm::GemmType> { + using L1AType = Gemm::GemmType; + using L1BType = Gemm::GemmType; + using L0AType = Gemm::GemmType; + using L0BType = Gemm::GemmType; +}; + +} // namespace Catlass::Gemv::helper + +#endif // CATLASS_GEMV_HELPER_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/kernel/kernel_gemv_aic.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/kernel/kernel_gemv_aic.hpp new file mode 100644 index 00000000..dadb5109 --- /dev/null +++ 
b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/kernel/kernel_gemv_aic.hpp @@ -0,0 +1,284 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef CATLASS_GEMV_KERNLE_GEMV_AIC_HPP +#define CATLASS_GEMV_KERNLE_GEMV_AIC_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/tile/copy_gm_to_ub.hpp" +#include "catlass/epilogue/tile/copy_ub_to_gm.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/gemv_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemv::Kernel { + +// tmeplate for gemv kernle, Compute z = αAx + βy +template +class KernelGemvAic +{ +public: + using BlockGemv = BlockGemv_; + using ArchTag = typename BlockGemv::ArchTag; + using L1TileShape = typename BlockGemv::L1TileShape; + using L0TileShape = typename BlockGemv::L0TileShape; + + using ElementX = typename BlockGemv::ElementX; + using LayoutX = typename BlockGemv::LayoutX; + + using ElementA = typename BlockGemv::ElementA; + using LayoutA = typename BlockGemv::LayoutA; + using ElementY = typename BlockGemv::ElementY; + using LayoutY = typename BlockGemv::LayoutY; + + using BlockEpilogue = BlockEpilogue_; + using ElementZ = typename BlockEpilogue::ElementZ; + using LayoutZ = typename BlockEpilogue::LayoutZ; + using EpilogueParams = typename BlockEpilogue::Params; + + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + struct Params { + // Data members + GemvCoord problemShape; + GM_ADDR ptrX; + LayoutX layoutX; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrWorkspace; + EpilogueParams epilogueParams; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemvCoord const &problemShape_, GM_ADDR ptrX_, LayoutX layoutX_, GM_ADDR ptrA_, LayoutA layoutA_, + GM_ADDR ptrWorkspace_, EpilogueParams const &epilogueParams_) + : problemShape(problemShape_), + ptrX(ptrX_), + layoutX(layoutX_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrWorkspace(ptrWorkspace_), + epilogueParams(epilogueParams_) + {} + }; + + struct Arguments { + GemvCoord problemShape; + ElementY alpha; + ElementY beta; + size_t elementSize; + GM_ADDR ptrX; + GM_ADDR ptrA; + GM_ADDR ptrZ; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return args.elementSize * args.problemShape.m(); + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + GemvCoord problemShape = args.problemShape; + uint32_t m = problemShape.m(); + uint32_t n = problemShape.n(); + LayoutX layoutX{n}; + LayoutA layoutA{m, n}; + LayoutZ layoutZ{m}; + typename BlockEpilogue::Params epilogueParams{args.alpha, args.beta, args.ptrZ, layoutZ, args.ptrZ, layoutZ}; + + Params params{problemShape, args.ptrX, layoutX, args.ptrA, layoutA, workspace, epilogueParams}; + return params; + } + + // Methods + 
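+    // --- Editor's sketch (illustrative, not part of the original patch): per the comment at
+    // the top of this file, the kernel family computes z = alpha * A * x + beta * y; the AIC
+    // phase writes its partial result to the workspace and the epilogue applies the
+    // alpha/beta scaling. A plain host-side reference for a row-major A; the function name
+    // is hypothetical.
+#if 0
+#include <cstdint>
+#include <vector>
+
+std::vector<float> ReferenceGemv(const std::vector<float> &A,  // m x n, row-major
+                                 const std::vector<float> &x,  // length n
+                                 const std::vector<float> &y,  // length m
+                                 uint32_t m, uint32_t n, float alpha, float beta)
+{
+    std::vector<float> z(m);
+    for (uint32_t i = 0; i < m; ++i) {
+        float acc = 0.0f;
+        for (uint32_t j = 0; j < n; ++j) {
+            acc += A[i * n + j] * x[j];
+        }
+        z[i] = alpha * acc + beta * y[i];
+    }
+    return z;
+}
+#endif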
CATLASS_DEVICE + KernelGemvAic() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockGemv blockGemv(resource); + // Represent the full gm + AscendC::GlobalTensor gmX; + gmX.SetGlobalBuffer((__gm__ ElementX *)params.ptrX); + + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + + AscendC::GlobalTensor gmY; + gmY.SetGlobalBuffer((__gm__ ElementY *)params.ptrWorkspace); + + layout::RowMajor layoutY(1, params.problemShape.m()); + + uint32_t maxMPerBlock = L1TileShape::M; + uint32_t maxNPerBlock = L1TileShape::N; + uint32_t M = params.problemShape.m(); + uint32_t N = params.problemShape.n(); + + uint32_t MLoops = CeilDiv(M, maxMPerBlock); + uint32_t coreLoops = MLoops; + uint32_t singleIdx = 0; + + static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; + static constexpr uint32_t L0C_TILE_SIZE = L0TileShape::M * L0TileShape::N; + static constexpr uint32_t L0C_TILE_NUM = L0C_SIZE / L0C_TILE_SIZE / sizeof(ElementAccumulator); + +#pragma unroll + for (uint32_t i = 0; i < L0C_TILE_NUM; i++) { + AscendC::SetFlag((event_t)i); + } + + for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { + // Compute Block location + uint32_t MGmBlockIdx = loopIdx; + uint32_t MGmActual = (MGmBlockIdx == MLoops - 1) ? (M - MGmBlockIdx * maxMPerBlock) : maxMPerBlock; + uint32_t NGmActual = N; + int64_t gmOffsetX; + int64_t gmOffsetA; + int64_t gmOffsetY; + int64_t gmOffsetNextX; + int64_t gmOffsetNextA; + int64_t gmOffsetNextY; + + if constexpr (std::is_same::value) { + gmOffsetX = 0; + gmOffsetA = MGmBlockIdx * maxMPerBlock * params.layoutA.stride(0); + + gmOffsetY = MGmBlockIdx * maxMPerBlock; + } else { + gmOffsetX = 0; + gmOffsetA = MGmBlockIdx * maxMPerBlock; + gmOffsetY = MGmBlockIdx * maxMPerBlock; + } + + bool isFirstBlock = (loopIdx == AscendC::GetBlockIdx()); + bool hasNextBlock = false; + uint32_t MNextGmBlockIdx; + GemvCoord nextActualBlockShape; + if (loopIdx + AscendC::GetBlockNum() < coreLoops) { + hasNextBlock = true; + uint32_t nextLoopIdx = loopIdx + AscendC::GetBlockNum(); + MNextGmBlockIdx = nextLoopIdx; + uint32_t MNextGmActual = + (MNextGmBlockIdx == MLoops - 1) ? 
(M - MNextGmBlockIdx * maxMPerBlock) : maxMPerBlock; + uint32_t NNextGmActual = N; + nextActualBlockShape = GemvCoord(MNextGmActual, NNextGmActual); + } + + if constexpr (std::is_same::value) { + gmOffsetNextX = 0; + gmOffsetNextA = MNextGmBlockIdx * maxMPerBlock * params.layoutA.stride(0); + + gmOffsetNextY = MNextGmBlockIdx * maxMPerBlock; + } else { + gmOffsetNextX = 0; + gmOffsetNextA = MNextGmBlockIdx * maxMPerBlock; + gmOffsetNextY = MNextGmBlockIdx * maxMPerBlock; + } + + GemvCoord actualBlockShape = GemvCoord(MGmActual, NGmActual); + + AscendC::WaitFlag((event_t)singleIdx % L0C_TILE_NUM); + + // Compute block-scoped matrix multiply-add + blockGemv(gmX[gmOffsetX], params.layoutX, gmA[gmOffsetA], params.layoutA, gmY[gmOffsetY], layoutY, + gmX[gmOffsetNextX], gmA[gmOffsetNextA], actualBlockShape, nextActualBlockShape, isFirstBlock, + hasNextBlock, singleIdx); + + Arch::CrossCoreSetFlagWithReverse<0x2, PIPE_FIX>(flagAicFinishStore); + AscendC::SetFlag((event_t)singleIdx % L0C_TILE_NUM); + + singleIdx++; + } + +#pragma unroll + for (uint32_t i = 0; i < L0C_TILE_NUM; i++) { + AscendC::WaitFlag((event_t)i); + } + + AscendC::PipeBarrier(); + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockEpilogue blockEpilogue(resource, params.epilogueParams); + + // Represent the full gm + AscendC::GlobalTensor gmY; + gmY.SetGlobalBuffer((__gm__ ElementY *)params.ptrWorkspace); + + layout::VectorLayout layoutY(params.problemShape.m()); + + // Get aicore information + uint32_t aicoreIndex = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t aicoreNum = AscendC::GetBlockNum(); + uint32_t subcoreIndex = AscendC::GetSubBlockIdx(); + + uint32_t maxMPerBlock = L1TileShape::M; + uint32_t M = params.problemShape.m(); + uint32_t MLoops = CeilDiv(M, maxMPerBlock); + uint32_t coreLoops = MLoops; + + // Loop through the epilogue calculations of each basic block + layout::VectorLayout::TensorCoord blockShape{L1TileShape::M}; + + for (uint32_t loopIdx = aicoreIndex; loopIdx < coreLoops; loopIdx += aicoreNum) { + // Compute block location + layout::VectorLayout::TensorCoord blockCoord{loopIdx}; + uint32_t MGmActual = (loopIdx == coreLoops) ? 
M - loopIdx * maxMPerBlock : maxMPerBlock; + + layout::VectorLayout::TensorCoord actualBlockShape{MGmActual}; + + // Get the offset + layout::VectorLayout::TensorCoord blockOffset = blockCoord * blockShape; + + // Get the data and layout of y under the current basic block + auto gmBlockY = gmY[layoutY.GetOffset(blockOffset)]; + auto layoutBlockY = layoutY.GetTileLayout(actualBlockShape); + + // Synchronize cross core + Arch::CrossCoreWaitFlagWithReverse<0x2, PIPE_MTE3>(flagAicFinishStore); + + // Actual calculatioin logic for performing block-scoped epilogue + blockEpilogue(blockOffset, actualBlockShape, gmBlockY, layoutBlockY); + } + + AscendC::PipeBarrier(); + } + +private: + // ID used for inter-core synchronization + static constexpr Arch::FlagID FLAG_AIC_FINISH_STORE = 0; + static constexpr Arch::FlagID RV_FLAG_AIC_FINISH_STORE = 1; + Arch::CrossCoreFlagWithReverse<> flagAicFinishStore{FLAG_AIC_FINISH_STORE, RV_FLAG_AIC_FINISH_STORE}; + Arch::Resource resource; + + static constexpr Arch::FlagID FLAG_AIV_FINISH_STORE = 0; + Arch::CrossCoreFlag flagAivFinishPadding{FLAG_AIV_FINISH_STORE}; +}; + +} // namespace Catlass::Gemv::Kernel + +#endif // CATLASS_GEMV_KERNLE_GEMV_EPILOGUE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/kernel/kernel_gemv_aiv.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/kernel/kernel_gemv_aiv.hpp new file mode 100644 index 00000000..c1f88615 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/kernel/kernel_gemv_aiv.hpp @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMV_KERNLE_GEMV_AIV_HPP +#define CATLASS_GEMV_KERNLE_GEMV_AIV_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemv_coord.hpp" + +namespace Catlass::Gemv::Kernel { + +// tmeplate for gemv kernle, Compute z = αAx + βy +template +class KernelGemvAiv +{ +public: + using BlockGemv = BlockGemv_; + using ArchTag = typename BlockGemv::ArchTag; + using UBTileShape = typename BlockGemv::UBTileShape; + using ElementA = typename BlockGemv::ElementA; + using LayoutA = typename BlockGemv::LayoutA; + using ElementX = typename BlockGemv::ElementX; + using LayoutX = typename BlockGemv::LayoutX; + using ElementY = typename BlockGemv::ElementY; + using LayoutY = typename BlockGemv::LayoutY; + using ElementAccumulator = typename BlockGemv::ElementAccumulator; + + /// Parameters structure + struct Params { + // Data members + GemvCoord problemShape; + GM_ADDR ptrA; + LayoutA layoutA; + GM_ADDR ptrX; + LayoutX layoutX; + GM_ADDR ptrY; + LayoutY layoutY; + GM_ADDR ptrZ; + float alpha; + float beta; + uint32_t split; + + // Methods + CATLASS_HOST_DEVICE + Params() {} + + CATLASS_HOST_DEVICE + Params(GemvCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrX_, LayoutX layoutX_, + GM_ADDR ptrY_, LayoutY layoutY_, GM_ADDR ptrZ_, float alpha_, float beta_, uint32_t split_) + : problemShape(problemShape_), + ptrA(ptrA_), + layoutA(layoutA_), + ptrX(ptrX_), + layoutX(layoutX_), + ptrY(ptrY_), + layoutY(layoutY_), + ptrZ(ptrZ_), + alpha(alpha_), + beta(beta_), + split(split_) + {} + }; + + // TODO: add arguments + struct Arguments { + GemvCoord problemShape; + GM_ADDR ptrA; + GM_ADDR ptrX; + GM_ADDR ptrY; + GM_ADDR ptrZ; + float alpha; + float beta; + uint32_t split; + }; + + static bool CanImplement(const Arguments &args) + { + return true; + } + + static size_t GetWorkspaceSize(const Arguments &args) + { + return sizeof(ElementY) * args.problemShape.m(); + } + + static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) + { + GemvCoord problemShape = args.problemShape; + uint32_t m = problemShape.m(); + uint32_t n = problemShape.n(); + LayoutA layoutA{m, n}; + LayoutX layoutX{n}; + LayoutY layoutY{m}; + Params params{problemShape, args.ptrA, layoutA, args.ptrX, layoutX, args.ptrY, + layoutY, args.ptrZ, args.alpha, args.beta, args.split}; + return params; + } + + // Methods + CATLASS_DEVICE + KernelGemvAiv() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms) {}; + + /// Executes one Matmul + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + {} + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + AscendC::SetAtomicNone(); + Arch::Resource resource; + BlockGemv blockGemv(resource); + uint32_t align = BYTE_PER_C0 / sizeof(ElementA); + uint32_t maxmPerBlock_round = RoundUp(UBTileShape::M, align); + uint32_t maxnPerBlock_round = RoundUp(UBTileShape::N, align); + + // add split k + uint32_t N_Split = RoundDown(params.problemShape.n(), params.split) / params.split; + uint32_t Mloopnum = CeilDiv(params.problemShape.m(), maxmPerBlock_round); + int32_t loopnum; + float Realbeta = params.beta; + if constexpr (std::is_same_v) { + loopnum = Mloopnum * params.split; + Realbeta = params.beta - 1.0f; + } else { + loopnum = Mloopnum; + } + + uint32_t offset_matrix; + uint32_t offset_vector_out; + uint32_t offset_vector_in = 0; + + // Represent the full gm + AscendC::GlobalTensor gmA; + 
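+        // --- Editor's note (illustrative, not part of the original patch): when split > 1,
+        // loop_id below enumerates (row-block, n-slice) pairs and is round-robined across the
+        // AIV cores via loop_id % aiv_num. A small host-side model of that schedule, with a
+        // hypothetical helper name:
+#if 0
+#include <cstdint>
+#include <cstdio>
+
+void PrintSplitKSchedule(uint32_t mBlocks, uint32_t split, uint32_t aivNum)
+{
+    uint32_t loopNum = mBlocks * split;
+    for (uint32_t loopId = 0; loopId < loopNum; ++loopId) {
+        uint32_t rowBlock = loopId / split;   // selects the block of rows of A and of y/z
+        uint32_t nSlice = loopId % split;     // selects the slice of the n dimension
+        uint32_t aivId = loopId % aivNum;     // core that executes this piece
+        std::printf("loop %u -> row-block %u, n-slice %u, aiv %u\n",
+                    loopId, rowBlock, nSlice, aivId);
+    }
+}
+#endif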
gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); + AscendC::GlobalTensor gmX; + gmX.SetGlobalBuffer((__gm__ ElementX *)params.ptrX); + AscendC::GlobalTensor gmY; + gmY.SetGlobalBuffer((__gm__ ElementY *)params.ptrY); + AscendC::GlobalTensor gmZ; + gmZ.SetGlobalBuffer((__gm__ ElementY *)params.ptrZ); + uint32_t aiv_num = AscendC::GetBlockNum() * AscendC::GetTaskRation(); + for (uint32_t loop_id = 0; loop_id < loopnum; loop_id++) { + uint32_t aiv_id = AscendC::GetBlockIdx(); + if (loop_id % aiv_num != aiv_id) continue; + uint32_t m_actual = ((int32_t)loop_id > (int32_t)(loopnum - params.split - 1)) + ? params.problemShape.m() - ((loop_id / params.split) * maxmPerBlock_round) + : maxmPerBlock_round; + uint32_t n_actual = params.problemShape.n(); + + if constexpr (std::is_same_v) { + offset_matrix = (loop_id % params.split) * N_Split * params.problemShape.m() + + (loop_id / params.split) * maxmPerBlock_round; + offset_vector_out = (loop_id / params.split) * maxmPerBlock_round; + offset_vector_in = (loop_id % params.split) * N_Split; + + if ((loop_id % params.split) == params.split - 1) { + n_actual = params.problemShape.n() - N_Split * (params.split - 1); + } else { + n_actual = N_Split; + } + } else { + offset_matrix = loop_id * maxmPerBlock_round * params.problemShape.n(); + offset_vector_out = loop_id * maxmPerBlock_round; + } + GemvCoord actualBlockShape = GemvCoord{m_actual, n_actual}; + + float realbeta = (loop_id % params.split == 0) ? Realbeta : 0.0f; + + blockGemv(gmA[offset_matrix], params.layoutA, gmX[offset_vector_in], params.layoutX, gmY[offset_vector_out], + params.layoutY, gmZ[offset_vector_out], actualBlockShape, params.alpha, realbeta); + } + + AscendC::PipeBarrier(); + } +}; + +} // namespace Catlass::Gemv::Kernel + +#endif // CATLASS_GEMV_KERNLE_GEMV_AIV_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/matrix_copy_gm_to_ub.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/matrix_copy_gm_to_ub.hpp new file mode 100644 index 00000000..5c7a2217 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/matrix_copy_gm_to_ub.hpp @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMV_TILE_MATRIX_COPY_GM_TO_UB_HPP +#define CATLASS_GEMV_TILE_MATRIX_COPY_GM_TO_UB_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm/gemm_type.hpp" + +namespace Catlass::Gemv::Tile { + +template +struct MatrixCopyGmToUB { + static_assert(DEPENDENT_FALSE, "Unsupported copy gm to UB, can not find the specialization."); +}; + +/// Partial specialization for AtlasA2, RowMajor in and RowMajor out. 
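+// --- Editor's note (illustrative, not part of the original patch): for the row-major fast
+// path below (row length and GM stride both multiples of ELE_NUM_PER_C0), the whole tile is
+// moved with a single DataCopy whose parameters are expressed in 32-byte blocks. A host-side
+// model of that parameter computation, with hypothetical struct/function names:
+#if 0
+#include <cstdint>
+
+struct CopyParamsModel {
+    uint32_t blockCount;  // one burst per row of the tile
+    uint32_t blockLen;    // row length in 32-byte blocks
+    uint32_t srcStride;   // GM gap between consecutive rows, in 32-byte blocks
+    uint32_t dstStride;   // UB gap between consecutive rows, in 32-byte blocks
+};
+
+CopyParamsModel MakeRowMajorCopyParams(uint32_t m_actual, uint32_t n_actual,
+                                       uint32_t n_round, uint32_t stride, uint32_t elePerC0)
+{
+    CopyParamsModel p{};
+    p.blockCount = m_actual;
+    p.blockLen = (n_actual + elePerC0 - 1) / elePerC0;   // CeilDiv(n_actual, elePerC0)
+    p.srcStride = (stride - n_actual) / elePerC0;
+    p.dstStride = (n_round - n_actual) / elePerC0;
+    return p;
+}
+#endif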
+/// Matrix A confirm +template +struct MatrixCopyGmToUB> { + using LayoutDst = layout::RowMajor; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + MatrixCopyGmToUB() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::GlobalTensor srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t m_actual = layoutSrc.shape(0); + uint32_t n_actual = layoutSrc.shape(1); + uint32_t m_round = layoutDst.shape(0); + uint32_t n_round = layoutDst.shape(1); + uint32_t stride = layoutSrc.stride(0); + + AscendC::DataCopyParams params; + if ((n_actual % ELE_NUM_PER_C0 == 0) && (stride % ELE_NUM_PER_C0 == 0) && (stride < STRIDE_LIMIT)) { + params.blockCount = m_actual; + params.blockLen = CeilDiv(n_actual, ELE_NUM_PER_C0); + params.srcStride = (stride - n_actual) / ELE_NUM_PER_C0; + params.dstStride = (n_round - n_actual) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, params); + } else if ((n_actual % ELE_NUM_PER_C0 == 0) && (stride * ELE_NUM_PER_C0 < STRIDE_LIMIT)) { + uint32_t counts = m_actual / ELE_NUM_PER_C0; + uint32_t remain = m_actual % ELE_NUM_PER_C0; + if (counts > 0) { + params.blockCount = counts; + params.blockLen = CeilDiv(n_actual, ELE_NUM_PER_C0); + params.srcStride = (ELE_NUM_PER_C0 * stride - n_actual) / ELE_NUM_PER_C0; + params.dstStride = (ELE_NUM_PER_C0 * n_round - n_actual) / ELE_NUM_PER_C0; + for (uint32_t i = 0; i < ELE_NUM_PER_C0; i++) { + AscendC::DataCopy(dstTensor[i * n_round], srcTensor[i * stride], params); + } + } + if (remain > 0) { + params.blockCount = 1; + params.blockLen = CeilDiv(n_actual, ELE_NUM_PER_C0); + params.srcStride = 0; + params.dstStride = 0; + for (uint32_t i = 0; i < remain; i++) { + AscendC::DataCopy(dstTensor[counts * n_round * ELE_NUM_PER_C0 + i * n_round], + srcTensor[counts * stride * ELE_NUM_PER_C0 + i * stride], params); + } + } + } else { + params.blockCount = 1; + params.blockLen = CeilDiv(n_actual, ELE_NUM_PER_C0); + params.srcStride = 0; + params.dstStride = 0; + for (uint32_t i = 0; i < m_actual; i++) { + AscendC::DataCopy(dstTensor[i * n_round], srcTensor[i * stride], params); + } + } + } +}; + +/// Partial specialization for AtlasA2, ColumnMajor in and ColumnMajor out. 
+/// Matrix A confirm +template +struct MatrixCopyGmToUB> { + using LayoutDst = layout::ColumnMajor; + using LayoutSrc = layout::ColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + MatrixCopyGmToUB() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t m_actual = layoutSrc.shape(0); + uint32_t n_actual = layoutSrc.shape(1); + uint32_t m_round = layoutDst.shape(0); + uint32_t n_round = layoutDst.shape(1); + uint32_t stride = layoutSrc.stride(1); + + AscendC::DataCopyParams params; + if ((m_actual % ELE_NUM_PER_C0 == 0) && (stride % ELE_NUM_PER_C0 == 0) && (stride < STRIDE_LIMIT)) { + params.blockCount = n_actual; + params.blockLen = CeilDiv(m_actual, ELE_NUM_PER_C0); + params.srcStride = (stride - m_actual) / ELE_NUM_PER_C0; + params.dstStride = (m_round - m_actual) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, params); + } else if ((m_actual % ELE_NUM_PER_C0 == 0) && (stride * ELE_NUM_PER_C0 < STRIDE_LIMIT)) { + uint32_t counts = n_actual / ELE_NUM_PER_C0; + uint32_t remain = n_actual % ELE_NUM_PER_C0; + if (counts > 0) { + params.blockCount = counts; + params.blockLen = CeilDiv(m_actual, ELE_NUM_PER_C0); + params.srcStride = (ELE_NUM_PER_C0 * stride - m_actual) / ELE_NUM_PER_C0; + params.dstStride = (ELE_NUM_PER_C0 * m_round - m_actual) / ELE_NUM_PER_C0; + for (uint32_t i = 0; i < ELE_NUM_PER_C0; i++) { + AscendC::DataCopy(dstTensor[i * m_round], srcTensor[i * stride], params); + } + } + if (remain > 0) { + params.blockCount = 1; + params.blockLen = CeilDiv(m_actual, ELE_NUM_PER_C0); + params.srcStride = 0; + params.dstStride = 0; + for (uint32_t i = 0; i < remain; i++) { + AscendC::DataCopy(dstTensor[counts * m_round * ELE_NUM_PER_C0 + i * m_round], + srcTensor[counts * stride * ELE_NUM_PER_C0 + i * stride], params); + } + } + } else { + params.blockCount = 1; + params.blockLen = CeilDiv(m_actual, ELE_NUM_PER_C0); + params.srcStride = 0; + params.dstStride = 0; + for (uint32_t i = 0; i < n_actual; i++) { + AscendC::DataCopy(dstTensor[i * m_round], srcTensor[i * stride], params); + } + } + } +}; + +} // namespace Catlass::Gemv::Tile + +#endif // CATLASS_GEMV_TILE_MATRIX_COPY_GM_TO_UB_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/tile_copy.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/tile_copy.hpp new file mode 100644 index 00000000..25f80831 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/tile_copy.hpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef CATLASS_GEMV_TILE_TILE_COPY_HPP +#define CATLASS_GEMV_TILE_TILE_COPY_HPP + +#include "catlass/catlass.hpp" +#include "catlass/detail/tag_to_layout.hpp" + +#include "catlass/gemv/tile/vec_copy_gm_to_ub.hpp" +#include "catlass/gemv/tile/vec_copy_ub_to_gm.hpp" +#include "catlass/gemv/tile/matrix_copy_gm_to_ub.hpp" + +#include "catlass/gemm/tile/copy_gm_to_l1.hpp" +#include "catlass/gemm/tile/copy_l0c_to_gm.hpp" +#include "catlass/gemm/tile/copy_l1_to_l0a.hpp" +#include "catlass/gemm/tile/copy_l1_to_l0b.hpp" + +#include "catlass/gemm/helper.hpp" +#include "catlass/gemv/helper.hpp" +#include "catlass/gemm/gemm_type.hpp" + +namespace Catlass::Gemv::Tile { + +template < + /// Tag indicating architecture + class ArchTag, + /// MatmulType for A matrix operand + class AType, + /// MatmulType type for X vector operand + class XType, + /// MatmulType type for Y vector operand + class YType, + /// MatmulTpe type for Bias operand + class BiasType = void> +struct TileCopyGemvAiv { + using ElementA = typename AType::Element; + using ElementX = typename XType::Element; + using ElementY = typename YType::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + // the function of aiv + using VecCopyGmToUb = Gemv::Tile::VecCopyGmToUB; + static constexpr bool is_atoadd = Gemv::helper::AtomicAddSelector::value; + using VecCopyUbToGm = Gemv::Tile::VecCopyUBToGm; + using MatrixCopyGmToUb = Gemv::Tile::MatrixCopyGmToUB; +}; + +template < + /// Tag indicating architecture + class ArchTag, + /// MatmulType for A matrix operand + class AType, + /// MatmulType type for X vector operand + class XType, + /// MatmulType type for Y vector operand + class YType, + /// MatmulTpe type for Bias operand + class BiasType = void> +struct TileCopyGemvAic { + using ElementA = typename AType::Element; + using ElementX = typename XType::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + // the function of aic + using L1XType = typename Gemv::helper::L1AndL0TypeSelectorGemv::L1AType; + using L1AType = typename Gemv::helper::L1AndL0TypeSelectorGemv::L1BType; + using L0AType = typename Gemv::helper::L1AndL0TypeSelectorGemv::L0AType; + using L0BType = typename Gemv::helper::L1AndL0TypeSelectorGemv::L0BType; + + using CopyGmToL1A = Gemm::Tile::CopyGmToL1; + using CopyGmToL1B = Gemm::Tile::CopyGmToL1; + + using CopyL1ToL0A = Gemm::Tile::CopyL1ToL0A; + using CopyL1ToL0B = Gemm::Tile::CopyL1ToL0B; + using CopyL0CToGm = Gemm::Tile::CopyL0CToGm; +}; + +} // namespace Catlass::Gemv::Tile + +#endif // CATLASS_GEMV_TILE_TILE_COPY_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/tile_vmad.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/tile_vmad.hpp new file mode 100644 index 00000000..f46aab1b --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/tile_vmad.hpp @@ -0,0 +1,319 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMV_TILE_TILE_VMAD_HPP +#define CATLASS_GEMV_TILE_TILE_VMAD_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/gemm/helper.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm/gemm_type.hpp" + +namespace Catlass::Gemv::Tile { + +template < + /// Tag indicating architecture + class ArchTag, class AType, class XType, class YType, class BiasType = void> +struct TileVmad { + static_assert(DEPENDENT_FALSE, "Unsupported TileVmad, can not find the specialization."); +}; + +template +struct TileVmad, + Gemm::GemmType, Gemm::GemmType, void> { + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + using LayoutDst = layout::RowMajor; + using LayoutSrc = layout::RowMajor; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementA); + + // Methods + + CATLASS_DEVICE + TileVmad() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor_v, + AscendC::LocalTensor srcTensor_m, AscendC::LocalTensor temp, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t m_actual = layoutSrc.shape(0); + uint32_t n_actual = layoutSrc.shape(1); + uint32_t m_round = layoutDst.shape(0); + uint32_t n_round = layoutDst.shape(1); + uint32_t temp_repeat_size = BYTE_PER_C0 * 8 / sizeof(ElementAccumulator); + uint32_t elem_repeat_size = ELE_NUM_PER_C0 * 8; + uint32_t mask = temp_repeat_size; + uint32_t repeattimes = CeilDiv(m_actual, temp_repeat_size); + AscendC::Duplicate(temp, (ElementAccumulator)0.0, temp_repeat_size, + CeilDiv(m_round * temp_repeat_size, temp_repeat_size), 1, 8); + + uint32_t repeat_num = n_actual / temp_repeat_size; + uint32_t remain = n_actual % temp_repeat_size; + + AscendC::PipeBarrier(); + AscendC::BinaryRepeatParams params; + params.dstBlkStride = 1; + params.src0BlkStride = 1; + params.src1BlkStride = 1; + params.dstRepStride = RoundUp(temp_repeat_size, temp_repeat_size) / (BYTE_PER_C0 / sizeof(ElementAccumulator)); + params.src0RepStride = RoundUp(n_round, elem_repeat_size) / ELE_NUM_PER_C0; + params.src1RepStride = 0; + AscendC::SetMaskCount(); + AscendC::SetVectorMask(m_actual * temp_repeat_size); + for (uint32_t i = 0; i < repeat_num; i++) { + uint32_t offset = i * temp_repeat_size; + AscendC::MulAddDst(temp, srcTensor_m[offset], srcTensor_v[offset], + AscendC::MASK_PLACEHOLDER, 1, params); + + AscendC::PipeBarrier(); + } + AscendC::SetMaskNorm(); + AscendC::ResetMask(); + + if (remain > 0) { + uint32_t offset = repeat_num * temp_repeat_size; + if (offset + remain > n_round) { + remain = n_round - offset; + } + uint64_t remain_mask = remain; + AscendC::MulAddDst(temp, srcTensor_m[offset], srcTensor_v[offset], + remain_mask, m_actual, params); + } + + uint64_t reduce_mask = (repeat_num == 0) ? remain : temp_repeat_size; + AscendC::PipeBarrier(); + AscendC::WholeReduceSum(temp, temp, reduce_mask, m_actual, 1, 1, 8); + AscendC::PipeBarrier(); + AscendC::UnaryRepeatParams castparams; + castparams.dstBlkStride = 1; + castparams.srcBlkStride = 1; + castparams.dstRepStride = 4; + castparams.srcRepStride = 8; + AscendC::Cast(srcTensor_m, temp, AscendC::RoundMode::CAST_NONE, + (uint64_t)mask, repeattimes, castparams); + AscendC::PipeBarrier(); + + uint64_t add_mask = (m_actual < elem_repeat_size) ? 
m_actual : elem_repeat_size; + params.dstRepStride = 8; + params.src0RepStride = 8; + params.src1RepStride = 8; + AscendC::Add(dstTensor, srcTensor_m, dstTensor, (uint64_t)add_mask, + CeilDiv(m_round, elem_repeat_size), params); + } +}; + +template <> +struct TileVmad, Gemm::GemmType, + Gemm::GemmType, void> { + using ElementA = float; + using ElementX = float; + using ElementY = float; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + using LayoutDst = layout::RowMajor; + using LayoutSrc = layout::RowMajor; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementA); + + // Methods + + CATLASS_DEVICE + TileVmad() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor_v, + AscendC::LocalTensor srcTensor_m, AscendC::LocalTensor temp, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t m_actual = layoutSrc.shape(0); + uint32_t n_actual = layoutSrc.shape(1); + uint32_t m_round = layoutDst.shape(0); + uint32_t n_round = layoutDst.shape(1); + + uint32_t repeat_size = ELE_NUM_PER_C0 * 8; + uint32_t mask = repeat_size; + uint32_t repeat_num = n_actual / repeat_size; + uint32_t remain = n_actual % repeat_size; + + AscendC::BinaryRepeatParams params; + params.dstBlkStride = 1; + params.src0BlkStride = 1; + params.src1BlkStride = 1; + params.dstRepStride = RoundUp(n_round, repeat_size) / ELE_NUM_PER_C0; + params.src0RepStride = RoundUp(n_round, repeat_size) / ELE_NUM_PER_C0; + params.src1RepStride = 0; + AscendC::SetMaskCount(); + AscendC::SetVectorMask(m_actual * repeat_size); + for (uint32_t i = 0; i < repeat_num; i++) { + uint32_t offset = i * repeat_size; + if (i == 0) { + AscendC::Mul(srcTensor_m, srcTensor_m, srcTensor_v, AscendC::MASK_PLACEHOLDER, 1, + params); + } else { + AscendC::MulAddDst(srcTensor_m, srcTensor_m[offset], srcTensor_v[offset], + AscendC::MASK_PLACEHOLDER, 1, params); + } + AscendC::PipeBarrier(); + } + AscendC::SetMaskNorm(); + AscendC::ResetMask(); + + if (remain > 0) { + uint32_t offset = repeat_num * repeat_size; + if (offset + remain > n_round) { + remain = n_round - offset; + } + uint64_t remain_mask = remain; + if (repeat_num == 0) { + AscendC::Mul(srcTensor_m, srcTensor_m, srcTensor_v, remain_mask, m_actual, params); + } else { + AscendC::MulAddDst(srcTensor_m, srcTensor_m[offset], srcTensor_v[offset], + remain_mask, m_actual, params); + } + } + + uint64_t reduce_mask = (repeat_num == 0) ? remain : repeat_size; + AscendC::PipeBarrier(); + AscendC::WholeReduceSum(srcTensor_m, srcTensor_m, reduce_mask, m_actual, 1, 1, + RoundUp(n_round, repeat_size) / ELE_NUM_PER_C0); + + uint64_t add_mask = (m_actual < repeat_size) ? 
m_actual : repeat_size; + params.dstRepStride = 8; + params.src0RepStride = 8; + params.src1RepStride = 8; + + AscendC::PipeBarrier(); + AscendC::Add(dstTensor, srcTensor_m, dstTensor, add_mask, CeilDiv(m_round, repeat_size), + params); + } +}; + +template +struct TileVmad, + Gemm::GemmType, Gemm::GemmType, void> { + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutDst = layout::ColumnMajor; + using LayoutSrc = layout::ColumnMajor; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementA); + + // Methods + + CATLASS_DEVICE + TileVmad() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor_v, + AscendC::LocalTensor srcTensor_m, AscendC::LocalTensor temp, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t m_actual = layoutSrc.shape(0); + uint32_t n_actual = layoutSrc.shape(1); + uint32_t m_round = layoutDst.shape(0); + uint32_t n_round = layoutDst.shape(1); + AscendC::SetMaskCount(); + AscendC::SetVectorMask(m_actual); + AscendC::Duplicate(temp, (ElementAccumulator)0.0, AscendC::MASK_PLACEHOLDER, 1, 1, + 8); + AscendC::PipeBarrier(); + + ElementX pix[32]; + AscendC::SetFlag((event_t)(0)); + AscendC::WaitFlag((event_t)(0)); + for (uint32_t i = 0; i < n_actual; i++) { + pix[i] = srcTensor_v.GetValue(i); + } + AscendC::SetFlag((event_t)(0)); + AscendC::WaitFlag((event_t)(0)); + + AscendC::UnaryRepeatParams params; + params.dstBlkStride = 1; + params.srcBlkStride = 1; + params.dstRepStride = 8; + params.srcRepStride = 4; + for (uint32_t i = 0; i < n_actual; i++) { + AscendC::Axpy(temp, srcTensor_m[i * m_round], pix[i], + AscendC::MASK_PLACEHOLDER, 1, params); + AscendC::PipeBarrier(); + } + params.dstRepStride = 4; + params.srcRepStride = 8; + AscendC::Cast(srcTensor_m, temp, AscendC::RoundMode::CAST_NONE, + AscendC::MASK_PLACEHOLDER, 1, params); + AscendC::BinaryRepeatParams addparams; + addparams.dstBlkStride = 1; + addparams.src0BlkStride = 1; + addparams.src1BlkStride = 1; + addparams.dstRepStride = 8; + addparams.src0RepStride = 8; + addparams.src1RepStride = 8; + AscendC::PipeBarrier(); + AscendC::Add(dstTensor, srcTensor_m, dstTensor, AscendC::MASK_PLACEHOLDER, 1, addparams); + AscendC::SetMaskNorm(); + AscendC::ResetMask(); + } +}; + +template <> +struct TileVmad, Gemm::GemmType, + Gemm::GemmType, void> { + using ElementA = float; + using ElementX = float; + using ElementY = float; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutDst = layout::ColumnMajor; + using LayoutSrc = layout::ColumnMajor; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementA); + + // Methods + + CATLASS_DEVICE + TileVmad() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor_v, + AscendC::LocalTensor srcTensor_m, AscendC::LocalTensor temp, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t m_actual = layoutSrc.shape(0); + uint32_t n_actual = layoutSrc.shape(1); + uint32_t m_round = layoutDst.shape(0); + uint32_t n_round = layoutDst.shape(1); + ElementX pix[32]; + AscendC::SetFlag((event_t)(0)); + AscendC::WaitFlag((event_t)(0)); + for (uint32_t i = 0; i < n_actual; i++) { + pix[i] = srcTensor_v.GetValue(i); + } + AscendC::SetFlag((event_t)(0)); + AscendC::WaitFlag((event_t)(0)); + AscendC::UnaryRepeatParams params; + params.dstBlkStride = 1; + params.srcBlkStride = 1; + params.dstRepStride = 8; + 
params.srcRepStride = 8; + AscendC::SetMaskCount(); + AscendC::SetVectorMask(m_actual); + for (uint32_t i = 0; i < n_actual; i++) { + AscendC::Axpy(dstTensor, srcTensor_m[i * m_round], pix[i], + AscendC::MASK_PLACEHOLDER, 1, params); + AscendC::PipeBarrier(); + } + AscendC::SetMaskNorm(); + AscendC::ResetMask(); + } +}; +} // namespace Catlass::Gemv::Tile + +#endif // CATLASS_GEMV_TILE_TILE_VMAD_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/tile_vmuls.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/tile_vmuls.hpp new file mode 100644 index 00000000..6ce05773 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/tile_vmuls.hpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_GEMV_TILE_VMULS_HPP +#define CATLASS_GEMV_TILE_VMULS_HPP + +#include "catlass/catlass.hpp" +#include "catlass/layout/layout.hpp" + +namespace Catlass::Gemv::Tile { + +template +struct TileVmuls { + using Element = typename VType_::Element; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + TileVmuls() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, Element scalar, + uint32_t len) + { + AscendC::SetMaskCount(); + AscendC::SetVectorMask(len); + AscendC::Muls(dstTensor, srcTensor, scalar, AscendC::MASK_PLACEHOLDER, 1, + AscendC::UnaryRepeatParams{}); + AscendC::SetMaskNorm(); + AscendC::ResetMask(); + } +}; +} // namespace Catlass::Gemv::Tile + +#endif // CATLASS_GEMV_TILE_VMULS_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/vec_copy_gm_to_ub.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/vec_copy_gm_to_ub.hpp new file mode 100644 index 00000000..b4e6b04f --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/vec_copy_gm_to_ub.hpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
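The TileVmad specializations above compute one GEMV tile: for each row of the A tile they reduce the elementwise products with the x slice (accumulating half inputs in float) and add the result into the destination vector, while TileVmuls simply scales a vector tile by a scalar. A minimal host-side reference of that arithmetic, useful for checking kernel output; the function names and the all-float element type are illustrative choices, not part of these headers:

#include <cstdint>
#include <vector>

// Reference semantics of one row-major TileVmad call:
// for each row i of the tile, dst[i] += sum_j A[i][j] * x[j].
void tile_vmad_ref(std::vector<float> &dst,          // length m (rows of the tile)
                   const std::vector<float> &a,      // m x n tile, row-major
                   const std::vector<float> &x,      // length n (slice of the input vector)
                   uint32_t m, uint32_t n)
{
    for (uint32_t i = 0; i < m; ++i) {
        float acc = 0.0f;                            // plays the role of the `temp` accumulator
        for (uint32_t j = 0; j < n; ++j) {
            acc += a[i * n + j] * x[j];
        }
        dst[i] += acc;                               // matches the final Add into dstTensor
    }
}

// Reference semantics of TileVmuls: dst[k] = src[k] * scalar for the first len elements.
void tile_vmuls_ref(std::vector<float> &dst, const std::vector<float> &src,
                    float scalar, uint32_t len)
{
    for (uint32_t k = 0; k < len; ++k) {
        dst[k] = src[k] * scalar;
    }
}

The column-major specializations reach the same result by accumulating one column of A at a time (Axpy with the matching element of x) instead of reducing along rows.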
+ */ + +#ifndef CATLASS_GEMV_TILE_VEC_COPY_GM_TO_UB_HPP +#define CATLASS_GEMV_TILE_VEC_COPY_GM_TO_UB_HPP + +#include "catlass/catlass.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm/gemm_type.hpp" + +constexpr uint32_t STRIDE_LIMIT = 65536; + +namespace Catlass::Gemv::Tile { + +template +struct VecCopyGmToUB { + using Element = typename VType_::Element; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + VecCopyGmToUB() {}; + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::GlobalTensor srcTensor, uint32_t len) + { + AscendC::DataCopyParams params; + params.blockCount = 1; + params.blockLen = CeilDiv(len, ELE_NUM_PER_C0); + params.srcStride = 0; + params.dstStride = 0; + AscendC::DataCopy(dstTensor, srcTensor, params); + } +}; +} // namespace Catlass::Gemv::Tile + +#endif // CATLASS_GEMV_TILE_VEC_COPY_GM_TO_UB_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/vec_copy_ub_to_gm.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/vec_copy_ub_to_gm.hpp new file mode 100644 index 00000000..f5a21302 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv/tile/vec_copy_ub_to_gm.hpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
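VecCopyGmToUB moves a vector tile from global memory to UB with a single DataCopy whose blockLen counts C0-sized blocks rather than elements, so a length that is not a multiple of ELE_NUM_PER_C0 is rounded up to a whole block. A small sketch of that rounding; the 32-byte C0 block size is an assumption about BYTE_PER_C0, and block_len_for is an illustrative name:

#include <cstdint>
#include <cstdio>

constexpr uint32_t BYTE_PER_C0 = 32;   // assumed size of one C0 block

// Mirrors CeilDiv(len, ELE_NUM_PER_C0), the value placed in DataCopyParams::blockLen.
template <typename Element>
uint32_t block_len_for(uint32_t len)
{
    constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
    return (len + ELE_NUM_PER_C0 - 1) / ELE_NUM_PER_C0;
}

int main()
{
    std::printf("%u\n", block_len_for<uint16_t>(100)); // 100 half elements = 200 B -> 7 blocks (224 B)
    std::printf("%u\n", block_len_for<float>(100));    // 100 float elements = 400 B -> 13 blocks (416 B)
    return 0;
}

The destination UB tile therefore has to be padded to the same boundary, which is what the RoundUp-based MakeLayoutInUb helpers elsewhere in these headers provide.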
+ */ + +#ifndef CATLASS_GEMV_TILE_VEC_COPY_UB_TO_GM_HPP +#define CATLASS_GEMV_TILE_VEC_COPY_UB_TO_GM_HPP + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/gemm/gemm_type.hpp" + +namespace Catlass::Gemv::Tile { + +template +struct VecCopyUBToGm { + static_assert(DEPENDENT_FALSE, "Unsupported copy UB to gm, can not find the specialization."); +}; + +template +struct VecCopyUBToGm> { + using LayoutSrc = layout::VectorLayout; + using LayoutDst = layout::VectorLayout; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + VecCopyUBToGm() {}; + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor dstTensor, AscendC::LocalTensor srcTensor, + layout::VectorLayout const &layoutDst, layout::VectorLayout const &layoutSrc) + { + AscendC::DataCopyExtParams params; + params.blockCount = 1; + params.blockLen = layoutDst.shape(0) * sizeof(Element); + params.srcStride = 0; + params.dstStride = 0; + params.rsv = 0; + AscendC::DataCopyPad(dstTensor, srcTensor, params); + } +}; + +template +struct VecCopyUBToGm, true> { + using LayoutSrc = layout::VectorLayout; + using LayoutDst = layout::VectorLayout; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + CATLASS_DEVICE + VecCopyUBToGm() {}; + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor dstTensor, AscendC::LocalTensor srcTensor, + layout::VectorLayout const &layoutDst, layout::VectorLayout const &layoutSrc) + { + AscendC::SetAtomicAdd(); + AscendC::DataCopyExtParams params; + params.blockCount = 1; + params.blockLen = layoutDst.shape(0) * sizeof(Element); + params.srcStride = 0; + params.dstStride = 0; + params.rsv = 0; + AscendC::DataCopyPad(dstTensor, srcTensor, params); + AscendC::SetAtomicNone(); + } +}; + +} // namespace Catlass::Gemv::Tile + +#endif // CATLASS_GEMV_TILE_VEC_COPY_UB_TO_GM_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv_coord.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv_coord.hpp new file mode 100644 index 00000000..c743dcfc --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/gemv_coord.hpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
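The two VecCopyUBToGm specializations copy a vector tile back to global memory with DataCopyPad, whose blockLen here is given in bytes (shape(0) * sizeof(Element)); the specialization ending in `true` additionally brackets the copy with SetAtomicAdd()/SetAtomicNone(), so the store accumulates into GM instead of overwriting it. A host-side sketch of the two behaviours, with the switch mirrored as a bool template parameter; the function name is illustrative:

#include <cstdint>
#include <vector>

// Reference model of the UB -> GM copy-back: plain store, or accumulate when ATOMIC_ADD is set.
template <bool ATOMIC_ADD, typename Element>
void copy_ub_to_gm_ref(std::vector<Element> &gm_dst, const std::vector<Element> &ub_src, uint32_t len)
{
    for (uint32_t i = 0; i < len; ++i) {
        if constexpr (ATOMIC_ADD) {
            gm_dst[i] += ub_src[i];   // SetAtomicAdd() path: partial results from different blocks sum up
        } else {
            gm_dst[i] = ub_src[i];    // default path: last writer wins
        }
    }
}

The atomic variant is the natural choice when several vector cores contribute partial GEMV results for the same rows of y.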
+ */ + +#ifndef CATLASS_GEMV_COORD_HPP +#define CATLASS_GEMV_COORD_HPP + +#include "catlass/coord.hpp" + +namespace Catlass { + +/// Shape of a matrix multiply-add operation +template < + /// Rows of matrix product + uint32_t M_ = 1, + /// Columns of the matrix (number of elements in the input vector) + uint32_t N_ = 1> +struct GemvShape { + static constexpr uint32_t M = M_; + static constexpr uint32_t N = N_; + + static constexpr int64_t MN = M * N; + + static constexpr int64_t COUNT = MN; + + /// Returns a Coord object + CATLASS_HOST_DEVICE + static Coord<2> ToCoord() + { + return MakeCoord(M, N); + } +}; + +/// GemvCoord is a structure derived from Coord<2> that specifies a location within the +/// coordinate space of a GEMV problem. +struct GemvCoord : public Coord<2, uint32_t> { + /// Integer-valued index + using Index = uint32_t; + + /// Base type is a Coord of rank=2 + using Base = Coord<2, Index>; + + /// GEMV M dimension - rows of the output vector (y) + static constexpr int M_INDEX = 0; + + /// GEMV N dimension - columns of the matrix (length of the input vector x) + static constexpr int N_INDEX = 1; + + /// Default ctor + CATLASS_HOST_DEVICE + GemvCoord() {} + + /// Constructs from Coord<2> and a batch + CATLASS_HOST_DEVICE + GemvCoord(Coord<2, Index> const &coord) : Base(coord) {} + + /// Helper to construct from M, N coordinates + CATLASS_HOST_DEVICE + GemvCoord(Index m, Index n) : Base(MakeCoord(m, n)) {} + + /// Returns the GEMV M coordinate (row of the result y) + CATLASS_HOST_DEVICE + Index const &m() const + { + return this->At(M_INDEX); + } + + /// Returns reference to the GEMV M coordinate + CATLASS_HOST_DEVICE + Index &m() + { + return this->At(M_INDEX); + } + + /// Returns the GEMV N coordinate (column of the matrix A or the input vector x) + CATLASS_HOST_DEVICE + Index const &n() const + { + return this->At(N_INDEX); + } + + /// Returns reference to the GEMV N coordinate + CATLASS_HOST_DEVICE + Index &n() + { + return this->At(N_INDEX); + } + + CATLASS_HOST_DEVICE + auto GetCoordMN() const + { + return this->GetCoordByAxis(); + } +}; + +} // namespace Catlass + +#endif // CATLASS_GEMV_COORD_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/layout/layout.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/layout/layout.hpp new file mode 100644 index 00000000..18275aae --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/layout/layout.hpp @@ -0,0 +1,18 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
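GemvShape fixes an (M, N) tile extent at compile time and GemvCoord carries a runtime (M, N) problem size, with m() indexing rows of the output y and n() the length of the input x. A small sketch of what those two extents imply for tiling; ProblemMN and TileCount are illustrative stand-ins, and the tile sizes are arbitrary example values:

#include <cstdint>
#include <cstdio>

struct ProblemMN {                  // stand-in for Catlass::GemvCoord: m() rows of y, n() elements of x
    uint32_t m;
    uint32_t n;
};

template <uint32_t TILE_M, uint32_t TILE_N>   // stand-in for a GemvShape<TILE_M, TILE_N> tile
uint64_t TileCount(ProblemMN p)
{
    uint64_t tilesM = (p.m + TILE_M - 1) / TILE_M;   // CeilDiv over the M extent
    uint64_t tilesN = (p.n + TILE_N - 1) / TILE_N;   // CeilDiv over the N extent
    return tilesM * tilesN;
}

int main()
{
    ProblemMN problem{1000, 7168};
    std::printf("%llu\n", static_cast<unsigned long long>(TileCount<32, 256>(problem)));  // 32 * 28 = 896
    return 0;
}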
+ */ + +#ifndef CATLASS_LAYOUT_LAYOUT_HPP +#define CATLASS_LAYOUT_LAYOUT_HPP + +#include "catlass/catlass.hpp" +#include "catlass/layout/matrix.hpp" +#include "catlass/layout/vector.hpp" + +#endif // CATLASS_LAYOUT_LAYOUT_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/layout/matrix.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/layout/matrix.hpp new file mode 100644 index 00000000..73daa9d7 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/layout/matrix.hpp @@ -0,0 +1,1555 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_LAYOUT_MATRIX_HPP +#define CATLASS_LAYOUT_MATRIX_HPP + +#include "catlass/catlass.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/alignment.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/conv_coord.hpp" + +namespace Catlass::layout { + +/// Mapping function for row-major matrices +struct RowMajor { +public: + /// Logical rank of tensor + static constexpr int RANK = 2; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + /// Constructor + CATLASS_HOST_DEVICE + RowMajor(Index rows = 0, Index cols = 0) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(cols), LongIndex(1))) + {} + + /// Constructor + CATLASS_HOST_DEVICE + RowMajor(Index rows, Index cols, LongIndex ldm) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(ldm, LongIndex(1))) + {} + + /// Ctor + CATLASS_HOST_DEVICE + RowMajor(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} + + template + CATLASS_HOST_DEVICE static RowMajor MakeLayout(Index rows, Index cols) + { + return RowMajor(rows, cols); + } + + template + CATLASS_HOST_DEVICE static RowMajor MakeLayoutInUb(MatrixCoord const &shape) + { + return RowMajor(shape.row(), shape.column(), RoundUp(shape.column())); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + CATLASS_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) * stride_[0] + LongIndex(coord.column()); + } + + /// Returns the layout of a tile. 
+ CATLASS_HOST_DEVICE + RowMajor GetTileLayout(MatrixCoord const &tileShape) const + { + return RowMajor(tileShape, stride()); + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + + /// Returns the length of the layout + CATLASS_HOST_DEVICE + size_t Capacity() + { + return static_cast(shape_[0]) * stride_[0]; + } + +private: + // + // Data members + // + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for col-major matrices +struct ColumnMajor { +public: + /// Logical rank of tensor + static constexpr int RANK = 2; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + CATLASS_HOST_DEVICE + ColumnMajor(Index rows = 0, Index cols = 0) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(1), LongIndex(rows))) + {} + + /// Constructor + CATLASS_HOST_DEVICE + ColumnMajor(Index rows, Index cols, LongIndex ldm) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(1), ldm)) + {} + + /// Ctor + CATLASS_HOST_DEVICE + ColumnMajor(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} + + template + CATLASS_HOST_DEVICE static ColumnMajor MakeLayout(Index rows, Index cols) + { + return ColumnMajor(rows, cols); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + CATLASS_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) + LongIndex(coord.column()) * stride_[1]; + } + + /// Returns the layout of a tile. 
+ CATLASS_HOST_DEVICE + ColumnMajor GetTileLayout(MatrixCoord const &tileShape) const + { + return ColumnMajor(tileShape, stride()); + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + + /// Returns the length of the layout + CATLASS_HOST_DEVICE + size_t Capacity() + { + return static_cast(shape_[1]) * stride_[1]; + } + +private: + // + // Data members + // + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for nZ matrices which is col-major inside fractal and row-major between fractal +struct nZ { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + CATLASS_HOST_DEVICE constexpr nZ( + Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal)) + {} + + /// Ctor + CATLASS_HOST_DEVICE constexpr nZ(OrgShape orgShape, Shape shape, Stride stride) + : orgShape_(orgShape), shape_(shape), stride_(stride) + {} + + /// Make the layout of a coordinate (row, column) + template + CATLASS_HOST_DEVICE constexpr static nZ MakeLayout(Index orgRows, Index orgCols) + { + constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return nZ(orgRows, 
orgCols, ELE_NUM_PER_C0, rowsRound / ELE_NUM_PER_C0, C0_NUM_PER_FRACTAL, + colsRound / C0_NUM_PER_FRACTAL, 1, colsRound * ELE_NUM_PER_C0, ELE_NUM_PER_C0, ELE_NUM_PER_FRACTAL); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + CATLASS_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3] + + (LongIndex(coord.row()) % shape_[0]) * stride_[0] + (LongIndex(coord.column()) % shape_[2]) * stride_[2]; + } + + /// Returns the layout of a tile. + CATLASS_HOST_DEVICE + nZ GetTileLayout(MatrixCoord const &tileOriShape) const + { + auto tileShape = MakeCoord(shape(0), CeilDiv(tileOriShape.row(), shape(0)), shape(2), + CeilDiv(tileOriShape.column(), shape(2))); + return nZ(tileOriShape, tileShape, stride()); + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + + /// Returns the length of the layout + CATLASS_HOST_DEVICE + size_t Capacity() + { + return static_cast(stride_[1]) * shape_[1]; + } + +private: + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for zN matrices which is row-major inside fractal and col-major between fractal +struct zN { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + CATLASS_HOST_DEVICE constexpr zN( + Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + LongIndex strideRowsInFractal = 0, /// number of elements between 
adjacent rows inside the fractal + LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal)) + {} + + /// Ctor + CATLASS_HOST_DEVICE constexpr zN(OrgShape orgShape, Shape shape, Stride stride) + : orgShape_(orgShape), shape_(shape), stride_(stride) + {} + + /// Make the layout of a coordinate (row, column) + template + CATLASS_HOST_DEVICE constexpr static zN MakeLayout(Index orgRows, Index orgCols) + { + constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return zN(orgRows, orgCols, C0_NUM_PER_FRACTAL, rowsRound / C0_NUM_PER_FRACTAL, ELE_NUM_PER_C0, + colsRound / ELE_NUM_PER_C0, ELE_NUM_PER_C0, ELE_NUM_PER_FRACTAL, 1, rowsRound * ELE_NUM_PER_C0); + } + + CATLASS_HOST_DEVICE + static zN MakeLayoutInL0C(MatrixCoord const &shape) + { + return zN(shape.row(), shape.column(), C0_NUM_PER_FRACTAL, CeilDiv(shape.row()), + C0_NUM_PER_FRACTAL, CeilDiv(shape.column()), C0_NUM_PER_FRACTAL, + C0_NUM_PER_FRACTAL * C0_NUM_PER_FRACTAL, 1, + RoundUp(shape.row()) * C0_NUM_PER_FRACTAL); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + CATLASS_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3] + + (LongIndex(coord.row()) % shape_[0]) * stride_[0] + (LongIndex(coord.column()) % shape_[2]) * stride_[2]; + } + + /// Returns the layout of a tile. 
+ CATLASS_HOST_DEVICE + zN GetTileLayout(MatrixCoord const &tileOriShape) const + { + auto tileShape = MakeCoord(shape(0), CeilDiv(tileOriShape.row(), shape(0)), shape(2), + CeilDiv(tileOriShape.column(), shape(2))); + return zN(tileOriShape, tileShape, stride()); + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + + /// Returns the length of the layout + CATLASS_HOST_DEVICE + size_t Capacity() + { + return static_cast(stride_[3]) * shape_[3]; + } + +private: + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for zN matrices which is row-major inside fractal and row-major between fractal +struct zZ { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + CATLASS_HOST_DEVICE constexpr zZ( + Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal)) + {} + + /// Ctor + CATLASS_HOST_DEVICE constexpr zZ(OrgShape orgShape, Shape shape, Stride stride) 
+ : orgShape_(orgShape), shape_(shape), stride_(stride) + {} + + /// Make the layout of a coordinate (row, column) + template + CATLASS_HOST_DEVICE constexpr static zZ MakeLayout(Index orgRows, Index orgCols) + { + constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return zZ(orgRows, orgCols, C0_NUM_PER_FRACTAL, rowsRound / C0_NUM_PER_FRACTAL, ELE_NUM_PER_C0, + colsRound / ELE_NUM_PER_C0, ELE_NUM_PER_C0, colsRound * C0_NUM_PER_FRACTAL, 1, ELE_NUM_PER_FRACTAL); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + CATLASS_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3]; + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for padding rowmajor matrices +/// A special data layout designed to improve the efficiency of matrix operations in non-512B aligned scenarios. +/// This layout is row-major within blocks and also row-major between blocks. 
+struct PaddingRowMajor { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + /// Constructor + CATLASS_HOST_DEVICE + PaddingRowMajor(Index orgRows = 0, Index orgCols = 0, Index blockRows = 0, Index blockCols = 0) + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(blockRows, CeilDiv(orgRows, blockRows), blockCols, CeilDiv(orgCols, blockCols))), + stride_(MakeCoord((LongIndex)blockCols, (LongIndex)blockRows * (LongIndex)RoundUp(orgCols, blockCols), + (LongIndex)1, (LongIndex)blockRows * (LongIndex)blockCols)) + {} + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + CATLASS_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + LongIndex blockRows = (LongIndex)shape_[0]; + LongIndex blockCols = (LongIndex)shape_[2]; + return (LongIndex)coord.row() / blockRows * stride_[1] + (LongIndex)coord.column() / blockCols * stride_[3] + + (LongIndex)coord.row() % blockRows * stride_[0] + (LongIndex)coord.column() % blockCols; + } + + CATLASS_HOST_DEVICE + PaddingRowMajor GetTileLayout(MatrixCoord const &tileShape) const + { + return PaddingRowMajor(tileShape.row(), tileShape.column(), shape_[0], shape_[2]); + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + // + // Data members + // + + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for padding columnmajor matrices +/// A special data layout designed to improve the efficiency of matrix operations in non-512B aligned scenarios. +/// This layout is column-major within blocks and also column-major between blocks. 
+struct PaddingColumnMajor { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + /// Constructor + CATLASS_HOST_DEVICE + PaddingColumnMajor(Index orgRows = 0, Index orgCols = 0, Index blockRows = 0, Index blockCols = 0) + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(blockRows, CeilDiv(orgRows, blockRows), blockCols, CeilDiv(orgCols, blockCols))), + stride_(MakeCoord((LongIndex)1, (LongIndex)blockRows * (LongIndex)blockCols, (LongIndex)blockRows, + (LongIndex)RoundUp(orgRows, blockRows) * (LongIndex)blockCols)) + {} + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + CATLASS_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + LongIndex blockRows = (LongIndex)shape_[0]; + LongIndex blockCols = (LongIndex)shape_[2]; + return (LongIndex)coord.row() / blockRows * stride_[1] + (LongIndex)coord.column() / blockCols * stride_[3] + + (LongIndex)coord.row() % blockRows + (LongIndex)coord.column() % blockCols * stride_[2]; + } + + CATLASS_HOST_DEVICE + PaddingColumnMajor GetTileLayout(MatrixCoord const &tileShape) const + { + return PaddingColumnMajor(tileShape.row(), tileShape.column(), shape_[0], shape_[2]); + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + // + // Data members + // + + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/////////////////////// +// new add layout nN +// nN layout +struct nN { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// 
Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + CATLASS_HOST_DEVICE + nN(Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + + LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal)) + {} + + /// Ctor + CATLASS_HOST_DEVICE + nN(OrgShape orgShape, Shape shape, Stride stride) : orgShape_(orgShape), shape_(shape), stride_(stride) {} + + /// Make the layout of a coordinate (row, column) + template + CATLASS_HOST_DEVICE static nN MakeLayout(Index orgRows, Index orgCols) + { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return nN(orgRows, orgCols, + + ELE_NUM_PER_C0, rowsRound / ELE_NUM_PER_C0, C0_NUM_PER_FRACTAL, colsRound / C0_NUM_PER_FRACTAL, + + 1, ELE_NUM_PER_FRACTAL, ELE_NUM_PER_C0, rowsRound * C0_NUM_PER_FRACTAL); + } + + /// Returns the offset of a coordinate in linear memory. 
+ /// Assumes coordinate has convention (row, column) + CATLASS_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3]; + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +struct NDC1HWC0 { +public: + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical rank of orgshape + /// (N,D,C1,H,W,C0) + static constexpr int ORG_SHAPE_RANK = 6; + + static constexpr int RANK = 5; + using OrgShape = Coord; + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + CATLASS_HOST_DEVICE constexpr NDC1HWC0(Index batch = 0, Index D = 0, Index C1 = 0, Index H = 0, Index W = 0, + Index C0 = 0, + + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + + LongIndex strideC0 = 0, /// number of elements between adjacent C0 cols + LongIndex strideHW = 0, /// number of elements between adjacent W rows + LongIndex StrideC1 = 0, /// number of elements between adjacent C1 cols + LongIndex StrideD = 0, /// number of elements between adjacent D batchCols + LongIndex StrideN = 0 /// number of elements between adjacent batch + ) + : orgShape_(MakeCoord(batch, D, C1, H, W, C0)), + shape_(MakeCoord(batch, rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideC0, strideHW, StrideC1, StrideD, StrideN)) + {} + + /// Ctor + CATLASS_HOST_DEVICE constexpr NDC1HWC0(OrgShape orgshape, Shape shape, Stride stride) + : orgShape_(orgshape), shape_(shape), stride_(stride) + {} + + CATLASS_HOST_DEVICE constexpr static NDC1HWC0 MakeLayout(Index Batch, Index D, Index C1, Index H, Index W, Index C0) + { + return NDC1HWC0(Batch, D, C1, H, W, C0, + + W, H, C0, D * C1, + + 1, /// StrideC0 + C0, /// StrideHW + H * W * C0, /// StrideC1 + H * W * C0 * C1, /// 
StrideD + H * W * C0 * C1 * D /// StrideN + ); + } + + // CATLASS_HOST_DEVICE + /// Returns the offset of a coordinate in linear memory. + CATLASS_HOST_DEVICE + LongIndex GetOffset(Conv3d6HdCoord const &coord) const + { + return LongIndex(coord.n()) * stride_[4] + LongIndex(coord.d()) * stride_[3] + + LongIndex(coord.c1()) * stride_[2] + LongIndex(coord.hw()) * stride_[1]; + } + + /// Returns the layout of a tile. + CATLASS_HOST_DEVICE + NDC1HWC0 GetTileLayout(OrgShape const &tileOriShape) const + { + Shape tileShape = + MakeCoord(tileOriShape[0], tileOriShape[4], tileOriShape[3], shape(3), tileOriShape[1] * tileOriShape[2]); + + Stride tileStride = + MakeCoord(stride(0), stride(1), (LongIndex)(tileOriShape[3] * tileOriShape[4] * shape(3)), + (LongIndex)(tileOriShape[2] * tileOriShape[3] * tileOriShape[4] * shape(3)), + (LongIndex)(tileOriShape[1] * tileOriShape[2] * tileOriShape[3] * tileOriShape[4] * shape(3))); + return NDC1HWC0(tileOriShape, tileShape, tileStride); + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +struct KDC1KHKWN1N0C0 { +public: + static constexpr int RANK = 4; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 4; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + CATLASS_HOST_DEVICE constexpr KDC1KHKWN1N0C0( + Index KdC1KhKw = 0, /// Merging Kd,Kh,Kw,C1 axes of KDC1KHKWN1N0C0 + Index N1 = 0, /// Cout = N1*N0 + Index N0 = 0, Index C0 = 0, + + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + + LongIndex strideC0 = 0, /// number of elements between adjacent rows inside the fractal + LongIndex StrideDC1HW = 0, /// number of elements between adjacent fractal rows + LongIndex strideN0 = 0, /// number of 
elements between adjacent cols inside the fractal + LongIndex strideN1 = 0 /// number of elements between adjacent fractal cols + ) + : orgShape_(MakeCoord(KdC1KhKw, N1, N0, C0)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideC0, strideN0, strideN1, StrideDC1HW)) + {} + + /// Ctor + CATLASS_HOST_DEVICE constexpr KDC1KHKWN1N0C0(OrgShape orgShape, Shape shape, Stride stride) + : orgShape_(orgShape), shape_(shape), stride_(stride) + {} + + /// Make the layout of a coordinate (Kd*C1*Kh*Kw,N1,N0,C0) + CATLASS_HOST_DEVICE constexpr static KDC1KHKWN1N0C0 MakeLayout(Index KdC1KhKw, Index N1, Index N0, Index C0) + { + return KDC1KHKWN1N0C0(KdC1KhKw, N1, N0, C0, + + C0, KdC1KhKw, N0, N1, + + 1, /// StrideC0 + C0 * N0 * N1, /// StrideDC1HW + C0, /// StrideN0 + C0 * N0 /// StrideN1 + ); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (KdC1KhKw_idx, N1_idx) + CATLASS_HOST_DEVICE + LongIndex GetOffset(Conv3dFracZ3dCoord const &coord) const + { + return LongIndex(coord.kdc1khkw()) * stride_[3] + LongIndex(coord.n1()) * stride_[1]; + } + + /// Returns the layout of a tile. + CATLASS_HOST_DEVICE + KDC1KHKWN1N0C0 GetTileLayout(OrgShape const &tileOriShape) const + { + Shape tileShape = MakeCoord(shape(0), /// C0 + tileOriShape[0], /// Kd*C1*Kh*Kw + shape(2), /// N0 + tileOriShape[1] /// N1 + ); + Stride tileStride = MakeCoord(stride(0), /// TileStrideC0 + stride(2) * tileOriShape[1] * tileOriShape[2], /// TileStrideDC1HW + (LongIndex)shape(0), /// TileStrideN0 + stride(2) * tileOriShape[2] /// TileStrideN1 + ); + return KDC1KHKWN1N0C0(tileOriShape, tileShape, tileStride); + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + CATLASS_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; +} // namespace Catlass::layout + +#endif // CATLASS_LAYOUT_MATRIX_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/layout/vector.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/layout/vector.hpp new file mode 100644 index 00000000..9590c5b3 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/layout/vector.hpp @@ -0,0 +1,131 @@ +/* + * 
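Every layout in matrix.hpp exposes the same GetOffset(MatrixCoord) contract: RowMajor and ColumnMajor map (row, column) through a plain leading-dimension stride, while the fractal layouts (nZ, zN, zZ, nN) first locate the enclosing fractal block and then the position inside it. A plain-C++ check of the two stride-based cases, using the same formulas as the headers; the function names are illustrative:

#include <cassert>
#include <cstdint>

// Same arithmetic as layout::RowMajor::GetOffset (stride = {ldm, 1}).
int64_t row_major_offset(uint32_t row, uint32_t col, int64_t ldm)      // ldm: elements per row
{
    return static_cast<int64_t>(row) * ldm + col;
}

// Same arithmetic as layout::ColumnMajor::GetOffset (stride = {1, ldm}).
int64_t column_major_offset(uint32_t row, uint32_t col, int64_t ldm)   // ldm: elements per column
{
    return static_cast<int64_t>(col) * ldm + row;
}

int main()
{
    // In a 4 x 6 matrix, element (2, 3) sits at 2*6 + 3 = 15 in row-major order
    // and at 3*4 + 2 = 14 in column-major order.
    assert(row_major_offset(2, 3, 6) == 15);
    assert(column_major_offset(2, 3, 4) == 14);
    return 0;
}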
Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_LAYOUT_VECTOR_HPP +#define CATLASS_LAYOUT_VECTOR_HPP + +#include "catlass/catlass.hpp" +#include "catlass/coord.hpp" + +namespace Catlass::layout { + +struct VectorLayout { +public: + /// Logical rank of tensor + static constexpr int RANK = 1; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Shape vector + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + + /// Logical coordinate + using TensorCoord = Coord; + +public: + // Methods + + CATLASS_HOST_DEVICE + VectorLayout(Index size = 0) : shape_(MakeCoord(size)), stride_(MakeCoord(LongIndex(1))) {} + + CATLASS_HOST_DEVICE + VectorLayout(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} + + template + CATLASS_HOST_DEVICE static VectorLayout MakeLayoutInUb(TensorCoord const &tileShape) + { + return VectorLayout{RoundUp(tileShape[0])}; + } + + CATLASS_HOST_DEVICE + LongIndex GetOffset(TensorCoord const &coord) const + { + return stride_[0] * coord[0]; + } + + /// Returns the layout of a tile. + CATLASS_HOST_DEVICE + VectorLayout GetTileLayout(TensorCoord const &tileShape) const + { + return VectorLayout(tileShape, stride()); + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + CATLASS_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + CATLASS_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } + +private: + /// Stride data member + Shape shape_; + Stride stride_; +}; + +} // namespace Catlass::layout + +#endif // CATLASS_LAYOUT_VECTOR_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/matrix_coord.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/matrix_coord.hpp new file mode 100644 index 00000000..156b1233 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/matrix_coord.hpp @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_MATRIX_COORD_HPP +#define CATLASS_MATRIX_COORD_HPP + +#include "catlass/coord.hpp" + +namespace Catlass { + +template +struct MatrixShape { + static constexpr uint32_t ROW = ROW_; + static constexpr uint32_t COLUMN = COLUMN_; + + static constexpr int64_t COUNT = ROW * COLUMN; + + CATLASS_HOST_DEVICE + static Coord<2> ToCoord() + { + return MakeCoord(ROW, COLUMN); + } +}; + +/// MatrixCoord wraps Coord<2, uint32_t> to provide a helper for accessing named dimensions. Classes +/// expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord. +struct MatrixCoord : public Coord<2, uint32_t> { + /// Integer-valued index + using Index = uint32_t; + + /// Base type is a Coord of rank=2 + using Base = Coord<2, Index>; + + /// LongIndex type + using LongIndex = typename Base::LongIndex; + + /// Rows dimension + static constexpr uint32_t ROW_INDEX = 0; + + /// Columns dimension + static constexpr uint32_t COLUMN_INDEX = 1; + + /// Default ctor + CATLASS_HOST_DEVICE + MatrixCoord() {} + + /// Constructs from Coord<2> + CATLASS_HOST_DEVICE + MatrixCoord(Coord<2, Index> const &coord) : Base(coord) {} + + /// Helper to construct from a row and column + CATLASS_HOST_DEVICE + MatrixCoord(Index row, Index column) : Base(MakeCoord(row, column)) {} + + /// Helper to construct from a row and column, which are LongIndex based + CATLASS_HOST_DEVICE + MatrixCoord(LongIndex row, LongIndex column) : Base(MakeCoord(Index(row), Index(column))) {} + + /// Returns the row of the coordinate + CATLASS_HOST_DEVICE + Index const &row() const + { + return this->At(ROW_INDEX); + } + + /// Returns the row of the coordinate + CATLASS_HOST_DEVICE + Index &row() + { + return this->At(ROW_INDEX); + } + + /// Returns the column of the coordinate + CATLASS_HOST_DEVICE + Index const &column() const + { + return this->At(COLUMN_INDEX); + } + + /// Returns the column of the coordinate + CATLASS_HOST_DEVICE + Index &column() + { + return this->At(COLUMN_INDEX); + } + + /// Element-wise addition + CATLASS_HOST_DEVICE + MatrixCoord operator+(Base const &b) const + { + return MatrixCoord(Base::operator+(b)); + } + + /// In-place addition + CATLASS_HOST_DEVICE + MatrixCoord &operator+=(Base const &b) + { + Base::operator+=(b); + return *this; + } +}; + +} // namespace Catlass + +#endif diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/status.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/status.hpp new file mode 100644 index 00000000..cb2cb349 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/catlass/status.hpp @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CATLASS_STATUS_HPP +#define CATLASS_STATUS_HPP + +namespace Catlass { + +enum class Status { kSuccess, kInvalid }; + +} // namespace Catlass + +#endif // CATLASS_STATUS_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/int_tuple.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/int_tuple.hpp new file mode 100644 index 00000000..93216ae8 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/int_tuple.hpp @@ -0,0 +1,237 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef TLA_INT_TUPLE_HPP +#define TLA_INT_TUPLE_HPP + +#include "tla/type_traits.hpp" +#include "tla/tuple.hpp" +#include "tla/numeric/integral_constant.hpp" +#include "tla/numeric/integer_sequence.hpp" + +namespace tla { +// +// Apply (Unpack) +// (t, f) => f(t_0,t_1,...,t_n) +// + +namespace detail { + +template +CATLASS_HOST_DEVICE constexpr auto apply(T &&t, F &&f, seq) +{ + return f(get(static_cast(t))...); +} + +template +CATLASS_HOST_DEVICE constexpr auto tapply(T &&t, F &&f, G &&g, seq) +{ + return g(f(get(static_cast(t)))...); +} + +template +CATLASS_HOST_DEVICE constexpr auto tapply(T0 &&t0, T1 &&t1, F &&f, G &&g, seq) +{ + return g(f(get(static_cast(t0)), get(static_cast(t1)))...); +} + +} // end namespace detail + +template +CATLASS_HOST_DEVICE constexpr auto apply(T &&t, F &&f) +{ + return detail::apply(static_cast(t), f, tuple_seq{}); +} + +template +CATLASS_HOST_DEVICE constexpr auto transform_apply(T &&t, F &&f, G &&g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t))); + } +} + +struct UnpackedMakeTuple { + template + CATLASS_HOST_DEVICE constexpr auto operator()(T const &...a) const + { + return tla::MakeTuple(a...); + } +}; + +template +CATLASS_HOST_DEVICE constexpr auto transform(T0 const &t0, T1 const &t1, F &&f) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + return detail::tapply(t0, t1, f, UnpackedMakeTuple{}, tuple_seq{}); + } else { + return f(t0, t1); + } +} + +template >::value)> +CATLASS_HOST_DEVICE constexpr decltype(auto) get(T &&t) noexcept +{ + static_assert(I == 0, "Index out of range"); + return static_cast(t); +} + +template +CATLASS_HOST_DEVICE constexpr decltype(auto) get(T &&t) noexcept +{ + return get(get(static_cast(t))); +} + +// max +template +CATLASS_HOST_DEVICE constexpr auto max(T0 const &t0, Ts const &...ts); + +struct UnpackedMax { + template + CATLASS_HOST_DEVICE constexpr auto operator()(T const &...v) const + { + return tla::max(v...); + } +}; + +template +CATLASS_HOST_DEVICE constexpr auto max(T0 const &t0, Ts const &...ts) +{ + if constexpr (is_tuple::value) { + return tla::max(tla::apply(t0, UnpackedMax{}), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return 
tla::max(t0, tla::max(ts...)); + } +} + +// rank +template +CATLASS_HOST_DEVICE constexpr auto rank(Tuple const &t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int::value>{}; + } else { + return Int<1>{}; + } + } else { + return rank(get(t)); + } +} + +template +using rank_t = decltype(rank(std::declval())); + +template +constexpr auto rank_v = rank_t::value; + +// depth +template +CATLASS_HOST_DEVICE constexpr auto depth(Tuple const &t); + +struct UnpackedDepth { + template + CATLASS_HOST_DEVICE constexpr auto operator()(T const &...v) const + { + return tla::max(depth(v)...); + } +}; + +template +CATLASS_HOST_DEVICE constexpr auto depth(Tuple const &t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int<1>{} + tla::apply(t, UnpackedDepth{}); + } else { + return Int<0>{}; + } + } else { + return depth(get(t)); + } +} + +template +using depth_t = decltype(depth(std::declval())); + +template +constexpr auto depth_v = depth_t::value; + +struct MultipliesUnaryLfold { + template + CATLASS_HOST_DEVICE constexpr auto operator()(T const &...v) const + { + return (... * v); + } +}; + +// Implementation of product as a function object +struct Product { + template + CATLASS_HOST_DEVICE constexpr auto operator()(IntTuple const &a) const + { + if constexpr (is_tuple::value) { + if constexpr (tuple_size::value == 0) { + return Int<1>{}; + } else { + return tla::transform_apply(a, Product{}, MultipliesUnaryLfold{}); + } + } else if constexpr (tla::is_integral::value) { + return a; + } + } +}; + +namespace detail { + +template +struct MakeZeroTupleImpl; + +template +struct MakeZeroTupleImpl> { + using type = tla::tuple...>; +}; + +template +using MakeZeroTuple = typename MakeZeroTupleImpl>::type; + +} // end namespace detail + +// Add +template +CATLASS_HOST_DEVICE constexpr auto Add(IntTupleA const &a, IntTupleB const &b); + +struct UnpackedAdd { + template + CATLASS_HOST_DEVICE constexpr auto operator()(IntTupleA const &x, IntTupleB const &y) const + { + return Add(x, y); + } +}; + +template +CATLASS_HOST_DEVICE constexpr auto Add(IntTupleA const &a, IntTupleB const &b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform(a, b, UnpackedAdd{}); + } else { + return tla::add(a, b); + } +} + +} // end namespace tla + +#endif // TLA_INT_TUPLE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/layout.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/layout.hpp new file mode 100644 index 00000000..68b3e3e6 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/layout.hpp @@ -0,0 +1,375 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef TLA_LAYOUT_HPP +#define TLA_LAYOUT_HPP + +#include "catlass/catlass.hpp" +#include "tla/numeric/integral_constant.hpp" +#include "tla/tuple.hpp" +#include "tla/int_tuple.hpp" +#include "catlass/layout/layout.hpp" + +namespace tla { + +// Aliases + +template +using Shape = tla::tuple; + +template +using Stride = tla::tuple; + +template +using Coord = tla::tuple; + +template +CATLASS_HOST_DEVICE constexpr Shape MakeShape(Ts const &...t) +{ + return {t...}; +} +template +CATLASS_HOST_DEVICE constexpr Stride MakeStride(Ts const &...t) +{ + return {t...}; +} +template +CATLASS_HOST_DEVICE constexpr Coord MakeCoord(Ts const &...t) +{ + return {t...}; +} + +// +// Layout +// + +template +struct Layout : private tla::tuple { + // NOTE: This defaults static Shapes/Strides correctly, but not dynamic + CATLASS_HOST_DEVICE constexpr Layout(Shape const &shape = {}, Stride const &stride = {}) + : tla::tuple(shape, stride) + {} + + // + // Accessors + // + + static constexpr int rank = rank_v; + static constexpr int depth = depth_v; + + template + CATLASS_HOST_DEVICE constexpr decltype(auto) shape() + { + return get<0, I...>(static_cast &>(*this)); + } + + template + CATLASS_HOST_DEVICE constexpr decltype(auto) shape() const + { + return get<0, I...>(static_cast const &>(*this)); + } + + template + CATLASS_HOST_DEVICE constexpr decltype(auto) stride() + { + return get<1, I...>(static_cast &>(*this)); + } + + template + CATLASS_HOST_DEVICE constexpr decltype(auto) stride() const + { + return get<1, I...>(static_cast const &>(*this)); + } + + template + CATLASS_HOST_DEVICE constexpr auto operator()(Coord const &coord) const + { + return crd2offset(coord, shape(), stride()); + } +}; + +// Layout construction + +template +CATLASS_HOST_DEVICE constexpr auto MakeLayout(Shape const &shape, Stride const &stride) +{ + static_assert(is_tuple::value || is_integral::value); + static_assert(is_tuple::value || is_integral::value); + return Layout(shape, stride); +} + +// Convenience tags for common layouts + +template +CATLASS_HOST_DEVICE constexpr auto MakeLayoutFromTag(LayoutTag const &tag) +{ + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || std::is_same_v, + "Unsupported LayoutTag for MakeLayoutFromTag, only support Catlass::layout::RowMajor or" + "Catlass::layout::ColumnMajor or Catlass::layout::zN or Catlass::layout::nZ"); + + if constexpr (std::is_same_v) { + return MakeLayout(MakeShape(tag.shape(0), tag.shape(1)), MakeStride(tag.stride(0), Int<1>{})); + } else if constexpr (std::is_same_v) { + return MakeLayout(MakeShape(tag.shape(0), tag.shape(1)), MakeStride(Int<1>{}, tag.stride(1))); + } else { // zN or nZ + return MakeLayout( + MakeShape(MakeShape(tag.shape(0), tag.shape(1)), MakeShape(tag.shape(2), tag.shape(3))), + MakeStride(MakeStride(tag.stride(0), tag.stride(1)), MakeStride(tag.stride(2), tag.stride(3)))); + } +} + +// Return the shape of a mode +template +CATLASS_HOST_DEVICE constexpr decltype(auto) shape(Layout &layout) +{ + return layout.template shape(); +} + +template +CATLASS_HOST_DEVICE constexpr decltype(auto) shape(Layout const &layout) +{ + return layout.template shape(); +} + +// Return the stride of a mode +template +CATLASS_HOST_DEVICE constexpr decltype(auto) stride(Layout &layout) +{ + return layout.template stride(); +} + +template +CATLASS_HOST_DEVICE constexpr decltype(auto) stride(Layout const &layout) +{ + return layout.template stride(); +} + +// Return the rank of layout +template +CATLASS_HOST_DEVICE constexpr auto rank(Layout const 
&layout) +{ + return rank(shape(layout)); +} + +// Return the depth of the layout +template +CATLASS_HOST_DEVICE constexpr auto depth(Layout const &layout) +{ + return depth(shape(layout)); +} + +// Return the offset of coord +template +CATLASS_HOST_DEVICE constexpr auto crd2offset(Coord const &coord, Shape const &shape, Stride const &stride); + +namespace detail { + +template +CATLASS_HOST_DEVICE constexpr auto crd2offset_ttt(Coord const &coord, Shape const &shape, Stride const &stride, + seq) +{ + return (... + crd2offset(get(coord), get(shape), get(stride))); +} + +template +CATLASS_HOST_DEVICE constexpr auto crd2offset_itt(CInt const &coord, STuple const &shape, DTuple const &stride, + seq) +{ + if constexpr (sizeof...(Is) == 0) { // Avoid recursion and mod on single/last iter + return crd2offset(coord, get(shape), get(stride)); + } else if constexpr (is_constant<0, CInt>::value) { + return crd2offset(_0{}, get(shape), get(stride)) + + (_0{} + ... + crd2offset(_0{}, get(shape), get(stride))); + } else { // General case + return crd2offset(coord % Product{}(get(shape)), get(shape), get(stride)) + + crd2offset_itt(coord / Product{}(get(shape)), shape, stride, seq{}); + } +} + +} // end namespace detail + +template +CATLASS_HOST_DEVICE constexpr auto crd2offset(Coord const &coord, Shape const &shape, Stride const &stride) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return detail::crd2offset_ttt(coord, shape, stride, tuple_seq{}); + } else { // tuple "int" "int" + static_assert(sizeof(Coord) == 0, "Invalid parameters"); + } + } else { + if constexpr (is_tuple::value) { // "int" tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return detail::crd2offset_itt(coord, shape, stride, tuple_seq{}); + } else { // "int" "int" "int" + return coord * stride; + } + } +} + +template +struct is_layout : false_type {}; +template +struct is_layout> : true_type {}; + +// Layout Check +namespace detail { + +template +struct isRowMajor { + static bool const value = false; +}; + +template +struct isRowMajor> { + static bool const value = (stride<1>(Layout{}) == 1); +}; + +template +struct isColumnMajor { + static bool const value = false; +}; + +template +struct isColumnMajor> { + static bool const value = (stride<0>(Layout{}) == 1); +}; + +template +struct iszN { + static bool const value = false; +}; + +template +struct iszN> { + static constexpr uint32_t ELE_NUM_PER_C0 = Catlass::BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = Catlass::BYTE_PER_FRACTAL / sizeof(Element); + static bool const value = + (shape<0, 0>(Layout{}) == Catlass::C0_NUM_PER_FRACTAL && shape<1, 0>(Layout{}) == ELE_NUM_PER_C0 && + stride<1, 0>(Layout{}) == 1 && stride<0, 1>(Layout{}) == ELE_NUM_PER_FRACTAL); +}; + +template +struct iszZ { + static bool const value = false; +}; + +template +struct iszZ> { + static constexpr uint32_t ELE_NUM_PER_C0 = Catlass::BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = Catlass::BYTE_PER_FRACTAL / sizeof(Element); + static bool const value = + (shape<0, 0>(Layout{}) == Catlass::C0_NUM_PER_FRACTAL && shape<1, 0>(Layout{}) == ELE_NUM_PER_C0 && + stride<1, 0>(Layout{}) == 1 && stride<1, 1>(Layout{}) == ELE_NUM_PER_FRACTAL); +}; + +template +struct isnZ { + static bool const value = false; +}; + 
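+// Note on these layout checks: each trait defaults to false and is specialized for a concrete
+// tla::Layout, so the category of a layout is decided at compile time from its static
+// shape/stride modes (unit stride for row-/column-major, the C0 x C0_NUM_PER_FRACTAL inner
+// block for the zN / zZ / nZ fractal forms).  A minimal offset sketch for the row-major case,
+// assuming layout::RowMajor exposes a (rows, cols) constructor:
+//
+//     auto layout = tla::MakeLayoutFromTag(Catlass::layout::RowMajor(rows, cols));
+//     auto offset = layout(tla::MakeCoord(i, j));  // crd2offset yields i * cols + j for contiguous rows
+//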
+template +struct isnZ> { + static constexpr uint32_t ELE_NUM_PER_C0 = Catlass::BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = Catlass::BYTE_PER_FRACTAL / sizeof(Element); + static bool const value = + (shape<0, 0>(Layout{}) == ELE_NUM_PER_C0 && shape<1, 0>(Layout{}) == Catlass::C0_NUM_PER_FRACTAL && + stride<0, 0>(Layout{}) == 1 && stride<1, 1>(Layout{}) == ELE_NUM_PER_FRACTAL); +}; + +} // end namespace detail + +// Advanced Layout constructions +// Make a inner layout with Rows and Cols. +template +CATLASS_HOST_DEVICE constexpr auto MakeLayout(T const &rows, U const &cols) +{ + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || std::is_same_v, + "Unsupported LayoutTag for MakeLayoutFromTag, only support Catlass::layout::RowMajor or" + "Catlass::layout::ColumnMajor or Catlass::layout::zN or Catlass::layout::nZ or Catlass::layout::zZ"); + + constexpr uint32_t ELE_NUM_PER_C0 = Catlass::BYTE_PER_C0 / sizeof(Element); + constexpr uint32_t ELE_NUM_PER_FRACTAL = Catlass::BYTE_PER_FRACTAL / sizeof(Element); + + if constexpr (std::is_same_v) { + return MakeLayout(MakeShape(rows, cols), MakeStride(cols, Int<1>{})); + } else if constexpr (std::is_same_v) { + return MakeLayout(MakeShape(rows, cols), MakeStride(Int<1>{}, rows)); + } else if constexpr (std::is_same_v) { + return MakeLayout( + MakeShape(MakeShape(Int{}, CeilDiv(rows, Int{})), + MakeShape(Int{}, CeilDiv(cols, Int{}))), + MakeStride(MakeStride(Int{}, Int{}), + MakeStride(Int<1>{}, RoundUp(rows, Int{}) * ELE_NUM_PER_C0))); + } else if constexpr (std::is_same_v) { + return MakeLayout( + MakeShape(MakeShape(Int{}, CeilDiv(rows, Int{})), + MakeShape(Int{}, CeilDiv(cols, Int{}))), + MakeStride( + MakeStride(Int{}, RoundUp(cols, Int{}) * Catlass::C0_NUM_PER_FRACTAL), + MakeStride(Int<1>{}, Int{}))); + } else { + return MakeLayout( + MakeShape(MakeShape(Int{}, CeilDiv(rows, Int{})), + MakeShape(Int{}, CeilDiv(cols, Int{}))), + MakeStride(MakeStride(Int<1>{}, RoundUp(cols, Int{}) * ELE_NUM_PER_C0), + MakeStride(Int{}, Int{}))); + } +} + +template +CATLASS_HOST_DEVICE constexpr auto MakeLayoutTile(Layout const &layout, ShapeNew const &shapeNew) +{ + static_assert(is_tuple::value && depth_v == 1 && rank_v == 2); + + if constexpr (Layout::depth == 1 && Layout::rank == 2) { + return MakeLayout(shapeNew, layout.stride()); + } else if constexpr (is_static(layout))>::value && + is_static(layout))>::value) { + const uint32_t rows = get<0>(shapeNew); + const uint32_t cols = get<1>(shapeNew); + constexpr uint32_t dstInnerShapeRow = decltype(shape<0, 0>(layout))::value; + constexpr uint32_t dstInnerShapeCol = decltype(shape<1, 0>(layout))::value; + return MakeLayout(MakeShape(MakeShape(Int{}, CeilDiv(rows)), + MakeShape(Int{}, CeilDiv(cols))), + layout.stride()); + } else { + const uint32_t rows = get<0>(shapeNew); + const uint32_t cols = get<1>(shapeNew); + const uint32_t dstInnerShapeRow = shape<0, 0>(layout); + const uint32_t dstInnerShapeCol = shape<1, 0>(layout); + return MakeLayout(MakeShape(MakeShape(dstInnerShapeRow, CeilDiv(rows, dstInnerShapeRow)), + MakeShape(dstInnerShapeCol, CeilDiv(cols, dstInnerShapeCol))), + layout.stride()); + } +} + +template +CATLASS_HOST_DEVICE constexpr auto MakeLayoutL0C(T const &rows, U const &cols) +{ + constexpr uint32_t ELE_NUM_PER_FRACTAL = 256; + return MakeLayout( + MakeShape(MakeShape(Int{}, CeilDiv(rows, Int{})), + MakeShape(Int{}, CeilDiv(cols, Int{}))), + MakeStride( + MakeStride(Int{}, Int{}), + MakeStride(Int<1>{}, RoundUp(rows, 
Int{}) * Catlass::C0_NUM_PER_FRACTAL))); +} + +} // end namespace tla + +#endif // TLA_LAYOUT_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/numeric/integer_sequence.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/numeric/integer_sequence.hpp new file mode 100644 index 00000000..45519bf5 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/numeric/integer_sequence.hpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef TLA_NUMERIC_INTEGER_SEQUENCE_HPP +#define TLA_NUMERIC_INTEGER_SEQUENCE_HPP + +#include "tla/numeric/integral_constant.hpp" +#include "tla/type_traits.hpp" + +namespace tla { + +template +struct IntegerSequence { + using value_type = T; + static constexpr size_t size() + { + return sizeof...(Ns); + } +}; + +template +struct MakeIntegerSequenceImpl; + +template +struct MakeIntegerSequenceImpl, T, 0> { + typedef IntegerSequence type; +}; + +template +struct MakeIntegerSequenceImpl, T, N> { + typedef typename MakeIntegerSequenceImpl, T, N - 1>::type type; +}; + +template +using MakeIntegerSequence = typename MakeIntegerSequenceImpl, T, N>::type; + +// index_sequence +template +using index_sequence = IntegerSequence; + +template +using make_index_sequence = MakeIntegerSequence; + +// int_sequence +template +using int_sequence = IntegerSequence; + +template +using make_int_sequence = MakeIntegerSequence; + +// Shortcuts +template +using seq = int_sequence; + +template +using make_seq = make_int_sequence; + +template +using tuple_seq = make_seq>::value>; + +} // end namespace tla + +#endif // TLA_NUMERIC_INTEGER_SEQUENCE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/numeric/integral_constant.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/numeric/integral_constant.hpp new file mode 100644 index 00000000..e70e9511 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/numeric/integral_constant.hpp @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef TLA_NUMERIC_INTEGER_CONSTANT_HPP +#define TLA_NUMERIC_INTEGER_CONSTANT_HPP + +#include "catlass/detail/macros.hpp" +#include "tla/type_traits.hpp" +#include "tla/numeric/math.hpp" + +namespace tla { + +// A constant value: short name and type-deduction for fast compilation +template +struct C { + using type = C; + static constexpr auto value = v; + using value_type = decltype(v); + CATLASS_HOST_DEVICE constexpr operator value_type() const noexcept + { + return value; + } + CATLASS_HOST_DEVICE constexpr value_type operator()() const noexcept + { + return value; + } +}; + +// Deprecate +template +using constant = C; + +template +using bool_constant = C; + +using true_type = bool_constant; +using false_type = bool_constant; + +template +using is_std_integral = std::is_integral; + +// A more std:: conforming integral_constant that enforces type but interops with C +template +struct integral_constant : C { + using type = integral_constant; + static constexpr T value = v; + using value_type = T; + CATLASS_HOST_DEVICE constexpr value_type operator()() const noexcept + { + return value; + } +}; + +// Use tla::is_std_integral to match built-in integral types (int, int64_t, unsigned, etc) +// Use tla::is_integral to match both built-in integral types AND static integral types. + +template +struct is_integral : bool_constant::value> {}; +template +struct is_integral> : true_type {}; +template +struct is_integral> : true_type {}; + +// is_static detects if an (abstract) value is defined completely by its type (no members) +template +struct is_static : bool_constant>::value> {}; + +// is_constant detects if a type is a static integral type and if v is equal to a value + +template +struct is_constant : false_type {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant> : bool_constant {}; +template +struct is_constant> : bool_constant {}; + +// +// Specializations +// + +template +using Int = C; +using _0 = Int<0>; +using _64 = Int<64>; +using _128 = Int<128>; +using _256 = Int<256>; +using _512 = Int<512>; + +/***************/ +/** Operators **/ +/***************/ + +#define TLA_LEFT_UNARY_OP(OP) \ + template \ + CATLASS_HOST_DEVICE constexpr C<(OP t)> operator OP(C) \ + { \ + return {}; \ + } +#define TLA_BINARY_OP(OP) \ + template \ + CATLASS_HOST_DEVICE constexpr C<(t OP u)> operator OP(C, C) \ + { \ + return {}; \ + } + +TLA_LEFT_UNARY_OP(+); +TLA_LEFT_UNARY_OP(-); +TLA_LEFT_UNARY_OP(~); +TLA_LEFT_UNARY_OP(!); +TLA_LEFT_UNARY_OP(*); + +TLA_BINARY_OP(+); +TLA_BINARY_OP(-); +TLA_BINARY_OP(*); +TLA_BINARY_OP(/); +TLA_BINARY_OP(%); +TLA_BINARY_OP(&); +TLA_BINARY_OP(|); +TLA_BINARY_OP(^); +TLA_BINARY_OP(<<); +TLA_BINARY_OP(>>); + +#undef TLA_BINARY_OP +#undef TLA_LEFT_UNARY_OP +#undef TLA_RIGHT_UNARY_OP + +// +// Named functions from math.hpp +// + +#define TLA_NAMED_UNARY_FN(OP) \ + template \ + CATLASS_HOST_DEVICE constexpr auto OP(C) \ + { \ + return C{}; \ + } +#define TLA_NAMED_BINARY_FN(OP) \ + template \ + CATLASS_HOST_DEVICE constexpr auto OP(C, C) \ + { \ + return C{}; \ + } \ + template ::value)> \ + CATLASS_HOST_DEVICE constexpr auto OP(C, U u) \ + { \ + return OP(t, u); \ + } \ + template ::value)> \ + CATLASS_HOST_DEVICE constexpr auto OP(T t, C) \ + { \ + return OP(t, u); \ + } + +TLA_NAMED_BINARY_FN(max); +TLA_NAMED_BINARY_FN(min); +TLA_NAMED_BINARY_FN(add); + +#undef TLA_NAMED_UNARY_FN +#undef 
TLA_NAMED_BINARY_FN + +} // end namespace tla + +#endif // TLA_NUMERIC_INTEGER_CONSTANT_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/numeric/math.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/numeric/math.hpp new file mode 100644 index 00000000..231888c2 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/numeric/math.hpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef TLA_NUMERIC_MATH_HPP +#define TLA_NUMERIC_MATH_HPP + +#include "catlass/detail/macros.hpp" +#include "tla/type_traits.hpp" + +namespace tla { + +// +// Common Operations +// + +template ::value &&std::is_arithmetic::value)> +CATLASS_HOST_DEVICE constexpr auto max(T const &t, U const &u) +{ + return t < u ? u : t; +} + +template ::value &&std::is_arithmetic::value)> +CATLASS_HOST_DEVICE constexpr auto min(T const &t, U const &u) +{ + return t < u ? t : u; +} + +template ::value &&std::is_arithmetic::value)> +CATLASS_HOST_DEVICE constexpr auto add(T const &t, U const &u) +{ + return t + u; +} + +} // namespace tla + +#endif // TLA_NUMERIC_MATH_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/tensor.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/tensor.hpp new file mode 100644 index 00000000..341ac8f1 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/tensor.hpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef TLA_TENSOR_HPP +#define TLA_TENSOR_HPP + +#include "catlass/arch/arch.hpp" +#include "tla/layout.hpp" // tla::Shape +#include "tla/numeric/integral_constant.hpp" // tla::is_integral +#include "tla/int_tuple.hpp" + +namespace tla { +// +// Tensor +// + +template +struct Tensor { + using Element = typename BuiltinTensor::PrimType; + using Layout = Layout_; + using Coord = Coord_; + static constexpr AscendC::TPosition position = Position; + + CATLASS_HOST_DEVICE constexpr Tensor() {} + + CATLASS_HOST_DEVICE constexpr Tensor(BuiltinTensor const &builtinTensor, Layout const &layout, + Coord const &coord = {}) + : rep_(builtinTensor, layout, coord) + {} + + // + // Accessors + // + + static constexpr int rank = Layout::rank; + + CATLASS_HOST_DEVICE constexpr decltype(auto) tensor() const + { + return *this; + } + + CATLASS_HOST_DEVICE constexpr decltype(auto) data() const + { + return get<0>(rep_); + } + + CATLASS_HOST_DEVICE constexpr decltype(auto) data() + { + return get<0>(rep_); + } + + CATLASS_HOST_DEVICE constexpr decltype(auto) layout() const + { + return get<1>(rep_); + } + + CATLASS_HOST_DEVICE constexpr decltype(auto) coord() const + { + return get<2>(rep_); + } + + CATLASS_HOST_DEVICE constexpr decltype(auto) shape() const + { + return layout().shape(); + } + + CATLASS_HOST_DEVICE constexpr decltype(auto) stride() const + { + return layout().stride(); + } + + tla::tuple rep_; +}; + +template +CATLASS_HOST_DEVICE constexpr auto MakeTensor(BuiltinTensor const &builtinTensor, Layout const &layout, PositionType) +{ + using Coord = detail::MakeZeroTuple; + return Tensor(builtinTensor, layout); +} + +template +CATLASS_HOST_DEVICE constexpr auto MakeTensor(BuiltinTensor const &builtinTensor, Layout const &layout, + Coord const &coord, PositionType) +{ + return Tensor(builtinTensor, layout, coord); +} + +template +CATLASS_DEVICE constexpr auto GetTile(Tensor const &tensor, Coord const &coord, Shape const &shape) +{ + auto layout = tensor.layout(); + auto builtinTensor = tensor.data(); + auto layoutNew = MakeLayoutTile(layout, shape); + auto coordNew = Add(tensor.coord(), coord); + return MakeTensor(builtinTensor, layoutNew, coordNew, Catlass::Arch::PositionType{}); +} + +} // end namespace tla + +#endif // TLA_TENSOR_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/tuple.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/tuple.hpp new file mode 100644 index 00000000..edea3d50 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/tuple.hpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef TLA_TUPLE_HPP +#define TLA_TUPLE_HPP + +#include "tla/numeric/integral_constant.hpp" +#include "tla/numeric/integer_sequence.hpp" + +namespace tla { + +namespace detail { + +// EBO stands for "empty base optimization." 
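+// tla::tuple (defined below) inherits from one index-tagged EBO base per element, so empty
+// element types -- in particular the static integral constants tla::C<v> -- contribute no
+// storage and getv() simply value-initializes them on access.  A minimal sketch of the effect,
+// assuming a C++17 host compiler and <type_traits>:
+//
+//     using StaticShape  = tla::tuple<tla::Int<16>, tla::Int<4>>;   // both elements are empty
+//     using DynamicShape = tla::tuple<uint32_t, uint32_t>;
+//     static_assert(std::is_empty<StaticShape>::value, "an all-constant tuple is an empty type");
+//     static_assert(sizeof(DynamicShape) == 2 * sizeof(uint32_t), "dynamic extents are stored");
+//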
+template ::value> +struct EBO; + +// Specialization for types T that are empty; +template +struct EBO { + CATLASS_HOST_DEVICE constexpr EBO() {} + + CATLASS_HOST_DEVICE constexpr EBO(T const &) {} +}; + +template +CATLASS_HOST_DEVICE constexpr T getv(EBO const &) +{ + return {}; +} + +// Specialization for types T that are not empty; +template +struct EBO { + CATLASS_HOST_DEVICE constexpr EBO() : t_{} {} + + CATLASS_HOST_DEVICE constexpr EBO(T const &t) : t_{t} {} + + T t_; +}; + +template +CATLASS_HOST_DEVICE constexpr T const &getv(EBO const &x) +{ + return x.t_; +} + +template +CATLASS_HOST_DEVICE constexpr T &getv(EBO &x) +{ + return x.t_; +} + +// TupleBase +template +struct TupleBase; + +template +struct TupleBase, T...> : EBO... { + CATLASS_HOST_DEVICE constexpr TupleBase() {} + + CATLASS_HOST_DEVICE constexpr TupleBase(T const &...t) : EBO(t)... {} +}; + +} // end namespace detail + +// tla::tuple class. +template +struct tuple : detail::TupleBase, T...> { + CATLASS_HOST_DEVICE constexpr tuple() {} + + CATLASS_HOST_DEVICE constexpr tuple(T const &...t) + : detail::TupleBase, T...>(t...) + {} +}; + +// get for tla::tuple +template +CATLASS_HOST_DEVICE constexpr decltype(auto) get(tuple const &t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +CATLASS_HOST_DEVICE constexpr decltype(auto) get(tuple &t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +CATLASS_HOST_DEVICE constexpr decltype(auto) get(tuple &&t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(static_cast &&>(t)); +} + +namespace detail { + +template +auto has_tuple_size(T *) -> bool_constant<(0 <= tuple_size::value)>; +auto has_tuple_size(...) -> false_type; + +} // end namespace detail + +template +struct is_tuple : decltype(detail::has_tuple_size((T *)0)){}; + +template +struct tuple_size> : std::integral_constant {}; + +template +struct tuple_size> : std::integral_constant {}; + +// make_tuple +template +CATLASS_HOST_DEVICE constexpr tuple MakeTuple(T const &...t) +{ + return {t...}; +} + +} // end namespace tla + +#endif // TLA_TUPLE_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/type_traits.hpp b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/type_traits.hpp new file mode 100644 index 00000000..68c566e6 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/catlass/tla/type_traits.hpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef TLA_UTIL_TYPE_TRAITS_HPP +#define TLA_UTIL_TYPE_TRAITS_HPP + +#pragma push_macro("inline") +#include +#pragma pop_macro("inline") + +#define __TLA_REQUIRES(...) 
typename std::enable_if<(__VA_ARGS__)>::type * = nullptr + +namespace tla { + +// using std::remove_cvref; +template +struct remove_cvref { + using type = std::remove_cv_t>; +}; + +// using std::remove_cvref_t; +template +using remove_cvref_t = typename remove_cvref::type; + +// tuple_size, tuple_element +template +struct tuple_size; + +template +struct tuple_size::type>> + : std::integral_constant::value> {}; + +template +constexpr size_t tuple_size_v = tuple_size::value; + +} // end namespace tla + +#endif // TLA_UTIL_TYPE_TRAITS_HPP diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/block/block_epilogue.h b/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/block/block_epilogue.h new file mode 100644 index 00000000..ac97b0de --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/block/block_epilogue.h @@ -0,0 +1,13 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../../catlass/catlass/epilogue/block/block_epilogue.hpp" + +#include "block_epilogue_per_token_dequant_swiglu.h" diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/block/block_epilogue_per_token_dequant_swiglu.h b/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/block/block_epilogue_per_token_dequant_swiglu.h new file mode 100644 index 00000000..54c4865f --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/block/block_epilogue_per_token_dequant_swiglu.h @@ -0,0 +1,326 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#pragma once +#include "../../catlass/catlass/catlass.hpp" +#include "../../catlass/catlass/arch/resource.hpp" +#include "../../catlass/catlass/epilogue/dispatch_policy.hpp" +#include "../../catlass/catlass/gemm_coord.hpp" +#include "../../catlass/catlass/matrix_coord.hpp" +#include "../../catlass/catlass/layout/layout.hpp" +#include "../../catlass/catlass/detail/callback.hpp" + +#include "../../epilogue/tile/tile_stride_muls.h" +#include "../../epilogue/tile/tile_stride_binary.h" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue, CType_, + Gemm::GemmType, Gemm::GemmType, DType_, + TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_, + EpilogueTileSwizzle_> +{ +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwiglu; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementScale = float; + using LayoutScale = LayoutScale_; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && std::is_same_v, + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + static_assert(TileShape::ROW * sizeof(float) % BYTE_PER_BLK == 0, + "The per token scale granularity for word calculation must be 32 bytes aligned."); + static_assert(TileShape::COLUMN % 2 == 0, "The n-axis needs to be divided into two parts."); + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static constexpr uint32_t CHUNK_TILE_COLUMN = TileShape::COLUMN / 2; + using ChunkTileShape = MatrixShape; + + using TileStrideMuls = Tile::TileStrideMuls; + using TileStrideDiv = Tile::TileStrideDiv; + using TileStrideMul = Tile::TileStrideMul; + + static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough."); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + + (TileShape::COUNT + ChunkTileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <= + ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementScale 
*ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + size_t ubOffset = 0; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementScale); + ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbScaleVMTE2List[i] = eventVMTE2++; + eventUbScaleMTE2VList[i] = eventMTE2V++; + eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; + eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbScaleVMTE2List[i]); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + ubTmpMxN = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubTmpMx32B = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * BYTE_PER_BLK; + ubTmpMxChunkN = resource.ubBuf.template GetBufferByByte(ubOffset); + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + CATLASS_DEVICE + void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutBlockC, Callback &&callback = Callback{}) + { + if (0 == actualBlockShapeMNK.k()) { + return; + } + callback(); + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto ubChunkTileStride = MakeCoord(static_cast(ChunkTileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = 0; // 
originally AscendC::GetSubBlockIdx();
+        uint32_t subblockNum = 1; // originally AscendC::GetSubBlockNum();
+        for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) {
+            auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx);
+            auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord);
+            auto tileOffsetInBlock = tileCoord * tileShape;
+            auto tileOffset = blockOffset + tileOffsetInBlock;
+
+            auto actualChunkTileShape = MakeCoord(actualTileShape.row(), actualTileShape.column() >> 1);
+            auto chunkTileOffset = MakeCoord(tileOffset.row(), tileOffset.column() >> 1);
+
+            auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
+            auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);
+
+            auto &ubC = ubCList[ubListId];
+            LayoutC layoutUbC{actualTileShape, ubTileStride};
+
+            AscendC::WaitFlag(eventUbCVMTE2List[ubListId]);
+            copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC);
+            AscendC::SetFlag(eventUbCMTE2VList[ubListId]);
+
+            auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>();
+            auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>();
+
+            auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
+            auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);
+
+            auto &ubScale = ubScaleList[ubListId];
+            auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape);
+
+            AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]);
+            copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale);
+            AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]);
+
+            auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
+            auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>();
+
+            auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)];
+            auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape);
+
+            auto &ubPerTokenScale = ubPerTokenScaleList[ubListId];
+            auto layoutUbPerTokenScale =
+                LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape);
+
+            AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]);
+            copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale,
+                                    layoutGmTilePerTokenScale);
+            AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]);
+
+            AscendC::WaitFlag(eventUbCMTE2VList[ubListId]);
+            AscendC::Cast(ubTmpMxN, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
+            AscendC::SetFlag(eventUbCVMTE2List[ubListId]);
+            AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]);
+            tileRowBroadcastMul(ubTmpMxN, ubTmpMxN, ubScale);
+            AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]);
+            AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]);
+            tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale);
+            AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]);
+
+            AscendC::PipeBarrier();
+            tileOneBlkColumnBroadcastMul(ubTmpMxN, ubTmpMxN, ubTmpMx32B);
+            AscendC::PipeBarrier();
+            tileStrideMuls(ubTmpMxChunkN, ubTmpMxN, -1.0f);
+            AscendC::PipeBarrier();
+            AscendC::Exp(ubTmpMxChunkN, ubTmpMxChunkN, ChunkTileShape::COUNT);
+            AscendC::PipeBarrier();
+            AscendC::Adds(ubTmpMxChunkN, ubTmpMxChunkN, 1.0f, ChunkTileShape::COUNT);
+            AscendC::PipeBarrier();
+            tileStrideDiv(ubTmpMxChunkN, ubTmpMxN, ubTmpMxChunkN);
+            AscendC::PipeBarrier();
+            auto &ubD = ubDList[ubListId];
+            LayoutD layoutUbD{actualChunkTileShape, ubChunkTileStride};
+
+            auto ubTmpMxNR = ubTmpMxN[ChunkTileShape::COLUMN];
+            AscendC::WaitFlag(eventUbDMTE3VList[ubListId]);
+            tileStrideMul(ubD, 
ubTmpMxNR, ubTmpMxChunkN); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + auto gmTileD = gmD[params.layoutD.GetOffset(chunkTileOffset)]; + auto layoutGmTileD = params.layoutD.GetTileLayout(actualChunkTileShape); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + } + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubTmpMxN; + AscendC::LocalTensor ubTmpMx32B; + AscendC::LocalTensor ubTmpMxChunkN; + + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; + + TileStrideMuls tileStrideMuls; + TileStrideDiv tileStrideDiv; + TileStrideMul tileStrideMul; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Catlass::Epilogue::Block diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/dispatch_policy.h b/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/dispatch_policy.h new file mode 100644 index 00000000..35aa6e13 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/dispatch_policy.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../catlass/catlass/epilogue/dispatch_policy.hpp" + +namespace Catlass::Epilogue { + +template +struct EpilogueAtlasA2PerTokenDequantSwiglu { + using ArchTag = Catlass::Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; +}; + +} // namespace Catlass::Epilogue diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/tile/tile_stride_binary.h b/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/tile/tile_stride_binary.h new file mode 100644 index 00000000..4f5e7cf0 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/tile/tile_stride_binary.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../../catlass/catlass/catlass.hpp" + +namespace Catlass::Epilogue::Tile { + +template +struct TileStrideBinary { + using ArchTag = ArchTag_; + using ElementCompute = ElementCompute_; + using TileShape = TileShape_; + static constexpr int64_t DST_STRIDE = DST_STRIDE_; + static constexpr int64_t SRC0_STRIDE = SRC0_STRIDE_; + static constexpr int64_t SRC1_STRIDE = SRC1_STRIDE_; + + static constexpr uint32_t MAX_REPEAT_TIMES = 255; + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(ElementCompute); + + static constexpr uint32_t DST_BLK_NUM_PER_COLUMN = DST_STRIDE / ELE_NUM_PER_BLK; + static constexpr uint32_t SRC0_BLK_NUM_PER_COLUMN = SRC0_STRIDE / ELE_NUM_PER_BLK; + static constexpr uint32_t SRC1_BLK_NUM_PER_COLUMN = SRC1_STRIDE / ELE_NUM_PER_BLK; + + static constexpr uint32_t ROW_NUM_PER_COMPUTE = MAX_REPEAT_TIMES; + static constexpr uint32_t COL_NUM_PER_COMPUTE = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + + CATLASS_DEVICE + TileStrideBinary() + { + repeatParams.dstBlkStride = 1; + repeatParams.src0BlkStride = 1; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = DST_BLK_NUM_PER_COLUMN; + repeatParams.src0RepStride = SRC0_BLK_NUM_PER_COLUMN; + repeatParams.src1RepStride = SRC1_BLK_NUM_PER_COLUMN; + } + + AscendC::BinaryRepeatParams repeatParams; +}; + +template +struct TileStrideMul + : TileStrideBinary { + using Base = TileStrideBinary; + + CATLASS_DEVICE + TileStrideMul() : Base() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubDst, + AscendC::LocalTensor const &ubSrc0, + AscendC::LocalTensor const &ubSrc1) + { + for (uint32_t rowOffset = 0; rowOffset < Base::TileShape::ROW; rowOffset += Base::ROW_NUM_PER_COMPUTE) { + uint32_t residueM = Base::TileShape::ROW - rowOffset; + uint8_t repeatTimes = + static_cast((residueM > Base::ROW_NUM_PER_COMPUTE) ? Base::ROW_NUM_PER_COMPUTE : residueM); + for (uint32_t colOffset = 0; colOffset < Base::TileShape::COLUMN; colOffset += Base::COL_NUM_PER_COMPUTE) { + uint32_t residueN = Base::TileShape::COLUMN - colOffset; + uint64_t mask = (residueN > Base::COL_NUM_PER_COMPUTE) ? Base::COL_NUM_PER_COMPUTE : residueN; + AscendC::Mul(ubDst[rowOffset * Base::DST_STRIDE + colOffset], + ubSrc0[rowOffset * Base::SRC0_STRIDE + colOffset], + ubSrc1[rowOffset * Base::SRC1_STRIDE + colOffset], mask, repeatTimes, this->repeatParams); + } + } + } +}; + +template +struct TileStrideDiv + : TileStrideBinary { + using Base = TileStrideBinary; + + CATLASS_DEVICE + TileStrideDiv() : Base() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubDst, + AscendC::LocalTensor const &ubSrc0, + AscendC::LocalTensor const &ubSrc1) + { + for (uint32_t rowOffset = 0; rowOffset < Base::TileShape::ROW; rowOffset += Base::ROW_NUM_PER_COMPUTE) { + uint32_t residueM = Base::TileShape::ROW - rowOffset; + uint8_t repeatTimes = + static_cast((residueM > Base::ROW_NUM_PER_COMPUTE) ? Base::ROW_NUM_PER_COMPUTE : residueM); + for (uint32_t colOffset = 0; colOffset < Base::TileShape::COLUMN; colOffset += Base::COL_NUM_PER_COMPUTE) { + uint32_t residueN = Base::TileShape::COLUMN - colOffset; + uint64_t mask = (residueN > Base::COL_NUM_PER_COMPUTE) ? 
Base::COL_NUM_PER_COMPUTE : residueN; + AscendC::Div(ubDst[rowOffset * Base::DST_STRIDE + colOffset], + ubSrc0[rowOffset * Base::SRC0_STRIDE + colOffset], + ubSrc1[rowOffset * Base::SRC1_STRIDE + colOffset], mask, repeatTimes, this->repeatParams); + } + } + } +}; + +} // namespace Catlass::Epilogue::Tile diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/tile/tile_stride_muls.h b/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/tile/tile_stride_muls.h new file mode 100644 index 00000000..bb49bb8c --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/epilogue/tile/tile_stride_muls.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../../catlass/catlass/catlass.hpp" + +namespace Catlass::Epilogue::Tile { + +template +struct TileStrideMuls { + using ArchTag = ArchTag_; + using ElementCompute = ElementCompute_; + using TileShape = TileShape_; + using DstTileShape = DstTileShape_; + using SrcTileShape = SrcTileShape_; + + static_assert(DstTileShape::ROW == SrcTileShape::ROW && DstTileShape::ROW == TileShape::ROW, "Error"); + + CATLASS_DEVICE + TileStrideMuls() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubDst, + AscendC::LocalTensor const &ubSrc, ElementCompute scalar) + { + constexpr uint32_t maxRepeatTimes = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + constexpr uint32_t dstBlkNumPerColumn = DstTileShape::COLUMN / eleNumPerBlk; + constexpr uint32_t srcBlkNumPerColumn = SrcTileShape::COLUMN / eleNumPerBlk; + AscendC::UnaryRepeatParams repeatParams; + repeatParams.dstBlkStride = 1; + repeatParams.srcBlkStride = 1; + repeatParams.dstRepStride = dstBlkNumPerColumn; + repeatParams.srcRepStride = srcBlkNumPerColumn; + + constexpr uint32_t rowNumPerCompute = maxRepeatTimes; + constexpr uint32_t colNumPerCompute = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += rowNumPerCompute) { + uint32_t residueM = TileShape::ROW - rowOffset; + uint8_t repeatTimes = static_cast((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM); + for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; colOffset += colNumPerCompute) { + uint32_t residueN = TileShape::COLUMN - colOffset; + uint64_t mask = (residueN > colNumPerCompute) ? colNumPerCompute : residueN; + AscendC::Muls(ubDst[rowOffset * DstTileShape::COLUMN + colOffset], + ubSrc[rowOffset * SrcTileShape::COLUMN + colOffset], scalar, mask, repeatTimes, + repeatParams); + } + } + } +}; + +} // namespace Catlass::Epilogue::Tile diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/gemm/block/block_mmad.h b/csrc/deepep/ops2/utils/op_kernel/operator/gemm/block/block_mmad.h new file mode 100644 index 00000000..e03140be --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/gemm/block/block_mmad.h @@ -0,0 +1,13 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. 
All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../../catlass/catlass/gemm/block/block_mmad.hpp" + +#include "block_mmad_preload_async_with_callback_resident_a.h" diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/gemm/block/block_mmad_preload_async_with_callback_resident_a.h b/csrc/deepep/ops2/utils/op_kernel/operator/gemm/block/block_mmad_preload_async_with_callback_resident_a.h new file mode 100644 index 00000000..278c7079 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/gemm/block/block_mmad_preload_async_with_callback_resident_a.h @@ -0,0 +1,420 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../../catlass/catlass/catlass.hpp" +#include "../../catlass/catlass/arch/resource.hpp" +#include "../../catlass/catlass/coord.hpp" +#include "../../catlass/catlass/detail/callback.hpp" +#include "../../catlass/catlass/gemm_coord.hpp" +#include "../../catlass/catlass/gemm/dispatch_policy.hpp" +#include "../../catlass/catlass/gemm/helper.hpp" + +namespace Catlass::Gemm::Block { + +template +struct BlockMmad< + MmadAtlasA2PreloadAsyncWithCallbackResidentA, + L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = + MmadAtlasA2PreloadAsyncWithCallbackResidentA; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using 
L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES; + static constexpr uint32_t L1A_STAGES = DispatchPolicy::L1A_STAGES; + static constexpr uint32_t L1B_STAGES = DispatchPolicy::L1B_STAGES; + static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES; + static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES; + static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + + // L1 tile size + static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + // L0 tile size + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator); + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert(L1A_TILE_SIZE * L1A_STAGES + L1B_TILE_SIZE * L1B_STAGES <= ArchTag::L1_SIZE, + "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!"); + + static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + + static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + InitL1(resource, l1BufAddrStart); + InitL0A(resource); + InitL0B(resource); + InitL0C(resource); + } + + CATLASS_DEVICE + ~BlockMmad() + { + SynchronizeBlock(); + for (uint32_t i = 0; i < L1A_STAGES; ++i) { + AscendC::WaitFlag(l1AEventList[i]); + } + for (uint32_t i = 0; i < L1B_STAGES; ++i) { + AscendC::WaitFlag(l1BEventList[i]); + } + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + AscendC::WaitFlag(l0AEventList[i]); + } + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + AscendC::WaitFlag(l0BEventList[i]); + } + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + AscendC::WaitFlag(l0CEventList[i]); + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmBlockA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmBlockB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutC, + GemmCoord const &actualShape, Callback const &callbackBeforeFixpipe, + Callback const &callbackAfterFixpipe) + { + uint32_t kTileCount = CeilDiv(actualShape.k()); + bool useResidentA = + (kTileCount == L1A_STAGES) && (!isFirstLoad) && (gmBlockA.GetPhyAddr() == lastGmBlockA.GetPhyAddr()); + isFirstLoad = false; + lastGmBlockA = gmBlockA; + + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + uint32_t startTileIdx = 0; + if 
constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx() % kTileCount; + } + + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) { + uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ? (startTileIdx + kLoopIdx) + : (startTileIdx + kLoopIdx - kTileCount); + + uint32_t kActual = + (kTileIdx < kTileCount - 1) ? L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K); + + // Emission load instruction from GM to L1 + MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K}; + MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + // Load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1AListId]); + if (!useResidentA) { + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1AListId], gmTileA, L1A_LAYOUT, layoutTileA); + } + AscendC::SetFlag(l1AEventList[l1AListId]); + // Load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1BListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1BListId], gmTileB, L1B_LAYOUT, layoutTileB); + AscendC::SetFlag(l1BEventList[l1BListId]); + + // If the number of preload instructions reaches the upper limit, perform an mmad calculation on L1 tile + if (preloadCount == PRELOAD_STAGES) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + } + + // Store the current load status + uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) + ? (l1TileMmadParamsId + preloadCount) + : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES); + auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId]; + l1TileMmadParams.l1AListId = l1AListId; + l1TileMmadParams.l1BListId = l1BListId; + l1TileMmadParams.mRound = mRound; + l1TileMmadParams.nRound = nRound; + l1TileMmadParams.kActual = kActual; + l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0); + l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1); + if (kLoopIdx == kTileCount - 1) { + l1TileMmadParams.gmBlockC = gmBlockC; + l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN()); + l1TileMmadParams.callbackBeforeFixpipe = callbackBeforeFixpipe; + l1TileMmadParams.callbackAfterFixpipe = callbackAfterFixpipe; + } + + if (preloadCount < PRELOAD_STAGES) { + ++preloadCount; + } else { + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + } + l1AListId = (l1AListId + 1 < L1A_STAGES) ? (l1AListId + 1) : 0; + l1BListId = (l1BListId + 1 < L1B_STAGES) ? (l1BListId + 1) : 0; + } + } + + CATLASS_DEVICE + void SynchronizeBlock() + { + while (preloadCount > 0) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? 
(l1TileMmadParamsId + 1) : 0; + --preloadCount; + } + } + +private: + struct L1TileMmadParams { + uint32_t l1AListId; + uint32_t l1BListId; + uint32_t mRound; + uint32_t nRound; + uint32_t kActual; + bool isKLoopFirst; + bool isKLoopLast; + AscendC::GlobalTensor gmBlockC; + LayoutC layoutCInGm; + Callback callbackBeforeFixpipe; + Callback callbackAfterFixpipe; + + CATLASS_DEVICE + L1TileMmadParams() = default; + }; + + CATLASS_DEVICE + void InitL1(Arch::Resource &resource, uint32_t l1BufAddrStart) + { + uint32_t l1AOffset = l1BufAddrStart; + for (uint32_t i = 0; i < L1A_STAGES; ++i) { + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_TILE_SIZE * i); + l1AEventList[i] = i; + AscendC::SetFlag(l1AEventList[i]); + } + uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1A_STAGES; + for (uint32_t i = 0; i < L1B_STAGES; ++i) { + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_TILE_SIZE * i); + l1BEventList[i] = i + L1A_STAGES; + AscendC::SetFlag(l1BEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0A(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); + l0AEventList[i] = i; + AscendC::SetFlag(l0AEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0B(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); + l0BEventList[i] = i + L0A_STAGES; + AscendC::SetFlag(l0BEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0C(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte(L0C_TILE_SIZE * i); + l0CEventList[i] = i; + AscendC::SetFlag(l0CEventList[i]); + } + } + + CATLASS_DEVICE + void L1TileMmad(L1TileMmadParams const ¶ms) + { + uint32_t mPartLoop = CeilDiv(params.mRound); + uint32_t nPartLoop = CeilDiv(params.nRound); + uint32_t kPartLoop = CeilDiv(params.kActual); + auto &l1ATensor = l1ATensorList[params.l1AListId]; + auto &l1BTensor = l1BTensorList[params.l1BListId]; + + auto &l0CTensor = l0CTensorList[l0CListId]; + LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound)); + + if constexpr (!ENABLE_UNIT_FLAG) { + if (params.isKLoopFirst) { + AscendC::WaitFlag(l0CEventList[l0CListId]); + } + } + + for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) { + uint32_t mPartActual = + (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M); + + for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) ? L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K); + + auto &l0ATile = l0ATensorList[l0AListId]; + auto layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK(); + auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[params.l1AListId]); + } + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT); + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[params.l1AListId]); + } + + for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? 
L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N); + + auto &l0BTile = l0BTensorList[l0BListId]; + auto layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN(); + auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)]; + + AscendC::WaitFlag(l0BEventList[l0BListId]); + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[params.l1BListId]); + } + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT); + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[params.l1BListId]); + } + + AscendC::SetFlag(EVENT_ID0); + + auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN(); + auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)]; + + AscendC::WaitFlag(EVENT_ID0); + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = (params.isKLoopFirst && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if (params.isKLoopLast && (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && + (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0; + } + } + + if (params.isKLoopLast) { + auto layoutCInGm = params.layoutCInGm; + + params.callbackBeforeFixpipe(); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(l0CEventList[l0CListId]); + AscendC::WaitFlag(l0CEventList[l0CListId]); + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0); + AscendC::SetFlag(l0CEventList[l0CListId]); + } else { + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11); + } + l0CListId = (l0CListId + 1 < L0C_STAGES) ? (l0CListId + 1) : 0; + + params.callbackAfterFixpipe(); + } + } + + AscendC::LocalTensor l1ATensorList[L1A_STAGES]; + AscendC::LocalTensor l1BTensorList[L1B_STAGES]; + int32_t l1AEventList[L1A_STAGES]; + int32_t l1BEventList[L1B_STAGES]; + uint32_t l1AListId{0}; + uint32_t l1BListId{0}; + + AscendC::LocalTensor l0ATensorList[L0A_STAGES]; + int32_t l0AEventList[L0A_STAGES]; + uint32_t l0AListId{0}; + + AscendC::LocalTensor l0BTensorList[L0B_STAGES]; + int32_t l0BEventList[L0B_STAGES]; + uint32_t l0BListId{0}; + + AscendC::LocalTensor l0CTensorList[L0C_STAGES_]; + int32_t l0CEventList[L0C_STAGES_]; + uint32_t l0CListId{0}; + + L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES]; + uint32_t l1TileMmadParamsId{0}; + uint32_t preloadCount{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; + + bool isFirstLoad{true}; + AscendC::GlobalTensor lastGmBlockA; +}; + +} // namespace Catlass::Gemm::Block diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/gemm/dispatch_policy.h b/csrc/deepep/ops2/utils/op_kernel/operator/gemm/dispatch_policy.h new file mode 100644 index 00000000..587f9a85 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/gemm/dispatch_policy.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. 
All rights reserved. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "../catlass/catlass/gemm/dispatch_policy.hpp" + +namespace Catlass::Gemm { + +template +struct MmadAtlasA2PreloadAsyncWithCallbackResidentA : public Catlass::Gemm::MmadAtlasA2Async { + static constexpr uint32_t PRELOAD_STAGES = PRELOAD_STAGES_; // Stages of emitting load instruction in advance + static constexpr uint32_t L1A_STAGES = L1A_STAGES_; + static constexpr uint32_t L1B_STAGES = L1B_STAGES_; + static constexpr uint32_t L0A_STAGES = L0A_STAGES_; + static constexpr uint32_t L0B_STAGES = L0B_STAGES_; + static constexpr uint32_t L0C_STAGES = L0C_STAGES_; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; +}; + +} // namespace Catlass::Gemm diff --git a/csrc/deepep/ops2/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h b/csrc/deepep/ops2/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h new file mode 100644 index 00000000..85bf1157 --- /dev/null +++ b/csrc/deepep/ops2/utils/op_kernel/operator/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h @@ -0,0 +1,1962 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+ * Description: FusedDeepMoe operator kernel function implementation file + * Author: WANG Qiankun + * Create: 2025-07-19 + * Note: + * History: 2025-07-19 create FusedDeepMoe operator kernel function implementation file + */ +#pragma once +#include "../../catlass/catlass/catlass.hpp" +#include "../../catlass/catlass/arch/cross_core_sync.hpp" +#include "../../catlass/catlass/arch/resource.hpp" +#include "../../catlass/catlass/coord.hpp" +#include "../../catlass/catlass/detail/callback.hpp" +#include "../../catlass/catlass/gemm_coord.hpp" +#include "../../catlass/catlass/matrix_coord.hpp" +#include "../../catlass/catlass/epilogue/tile/tile_swizzle.hpp" +#include "../../catlass/catlass/epilogue/tile/tile_copy.hpp" + +#include "../../../../../op_kernel/fused_deep_moe_base.h" + +constexpr uint32_t tokenLength = 7168; +constexpr uint32_t axisK_ = 8; +constexpr uint32_t STATE_OFFSET = 512; +constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024; +constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024; +constexpr uint64_t GROUP_TOKEN_NUM_OFFSET = 932 * 1024; +constexpr uint64_t SOFT_SYNC_OFFSET = 940 * 1024; +constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; +constexpr uint32_t SUM_TMP_TENSOR_SIZE = 1024; +constexpr uint32_t UB_ALIGN = 32; +constexpr uint32_t TOKEN_EXTRA_SPACE = 512; +constexpr uint32_t INT32_COUNT_PER_BLOCK = 8; +constexpr uint32_t SOFT_SYNC_SPACE_SIZE = 512; +constexpr uint32_t COMP_AIV_CORE_NUM = 24; // 24 AIVs perform the dequant-swiglu computation; not user-adjustable yet +constexpr uint32_t SEND_AIV_CORE_NUM = 48; // all cores send/receive with a single expert per rank; halved with multiple experts +constexpr uint32_t RECV_AIV_CORE_NUM = 48; // all cores send/receive with a single expert per rank; halved with multiple experts +constexpr int64_t LOOP_TMP_SIZE = 4096; // scratch space used by the address-offset optimization +constexpr int32_t SUB_AIV_NUM = 2; // 1C with 2V, i.e. one cube core paired with two vector cores +constexpr int32_t ODD_EVEN_BASE = 2; // base for odd/even checks +constexpr int32_t BUFFER_NUM = 2; +constexpr int32_t GATHER_SECOND_NUM = 2; +constexpr uint32_t OPT_RANK_OFFSET = 512; // NPLB optimization variable + +#define CEIL_UP(x) ((x + UB_ALIGN - 1) / UB_ALIGN * UB_ALIGN) +#define CEIL(x, y) (((x) + (y - 1)) / (y)) +#define UB_BLOCK_SIZE (32) +#define GET_WIND_STATE_ADDR_BY_RANK_ID(rankId) \ + (((epRankId == rankId) \ + ? ((GM_ADDR)(winContext_->localWindowsExp)) \ + : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsExp))) + \ + dataState * WIN_STATE_OFFSET) +#define GET_WIND_ADDR_BY_RANK_ID(rankId) \ + (((epRankId == rankId) \ + ?
((GM_ADDR)(winContext_->localWindowsIn)) \ + : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsIn))) + \ + winDataSizeOffset + rankId * OPT_RANK_OFFSET) +#define TOKEN_FLAG_1 (0x55555555) +#define TOKEN_FLAG_2 (0x33333333) +#define V_TO_C_FLAG_1 (0x03030303) +#define V_TO_C_FLAG_2 (0x05050505) +#define AIC_STATE_SPACE_IDNEX (48) +#define AIV_STATE_SPACE_IDNEX (72) +#define CV_FLAG_INDEX 0 +#define GROUP_ID_INDEX 1 +#define PRE_COUNT_INDEX 2 +#define SELF_COUNT_INDEX 3 +#define TOTAL_COUNT_INDEX 4 +#define GROUP_TOKEN_COUNT 3 // 等于SELF_COUNT_INDEX +#define GROUP_INFO_SIZE 8 + +#define REACH_STEP_1_SEND_COUNT +#define REACH_STEP_2_SEND_TOKEN +#define REACH_STEP_3_RECV_COUNT +#define REACH_STEP_4_RECV_TOKEN +#define REACH_STEP_5_WAIT_RECV_CORE +#define REACH_STEP_6_GMM1_DEQ_SWIGLU +#define REACH_STEP_7_UPDATE_INFO +#define REACH_STEP_8_QUANT + +#define SEND_TOKEN_RETURN // 这个宏好像比较影响性能,待确认 + +namespace Catlass::Gemm::Kernel { + +template +class BlockQuant +{ +public: + using ElementInput = float; + using LayoutInput = layout::RowMajor; + using ElementDequantScale = float; + using LayoutDequantScale = layout::VectorLayout; + using ElementOutput = int8_t; + using LayoutOutput = layout::RowMajor; + + using InputType = GemmType; + using DequantScaleType = GemmType; + using OutputType = GemmType; + + constexpr static uint32_t TILE_ROW = 8; + constexpr static uint32_t TILE_COLUMN = 2048; + constexpr static uint32_t HALF_TILE_COLUMN = 1024; + using TileShape = MatrixShape; + using HalfTileShape = MatrixShape; + + using EpilogueTileSwizzle = Epilogue::Tile::EpilogueHorizontalTileSwizzle; + + struct Params { + __gm__ ElementInput *ptrInput{nullptr}; + LayoutInput layoutInput; + __gm__ ElementDequantScale *ptrDequantScale{nullptr}; + LayoutDequantScale layoutDequantScale; + __gm__ ElementOutput *ptrOutput{nullptr}; + LayoutOutput layoutOutput; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementInput *ptrInput_, LayoutInput const &layoutInput_, + __gm__ ElementDequantScale *ptrQuantScale_, LayoutDequantScale const &layoutQuantScale_, + __gm__ ElementOutput *ptrOutput_, LayoutOutput const layoutOutput_) + : ptrInput(ptrInput_), + layoutInput(layoutInput_), + ptrDequantScale(ptrQuantScale_), + layoutDequantScale(layoutQuantScale_), + ptrOutput(ptrOutput_), + layoutOutput(layoutOutput_) + {} + }; + + CATLASS_DEVICE + BlockQuant(Arch::Resource const &resource, Params const ¶ms_) : params(params_) + { + int64_t ubOffset = 0; + ubInput = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementInput); + ubDequantScale = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementDequantScale); + ubOutput = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementOutput); + + ubAbs = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubMax = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += HalfTileShape::COUNT * sizeof(float); + ubReduceMax = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(float); + ubQuantScale = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(float); + ubInputTmp = ubAbs; + ubQuantF32 = ubAbs; + ubQuantS32 = ubAbs.ReinterpretCast(); + ubQuantF16 = ubAbs.ReinterpretCast(); + + AscendC::SetFlag(0); + AscendC::SetFlag(0); + AscendC::SetFlag(1); + } + + 
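// Summary of the quantization done by BlockQuant::operator() above (names such as rowMax
// are illustrative only, not identifiers from this file): for every row of the 8 x 2048 tile,
//   rowMax           = max(|input[row][:]|)                      // Abs + pairwise Max + WholeReduceMax
//   dequantScale     = rowMax / 127.0f                           // one value per row, written to gmDequantScale
//   output[row][col] = rint(input[row][col] * 127.0f / rowMax)   // fp32 -> s32 -> f16 -> s8 cast chain
// so that on the dequantizing side input is approximately output * dequantScale.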
CATLASS_DEVICE + ~BlockQuant() + { + AscendC::WaitFlag(0); + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + } + + CATLASS_DEVICE + void operator()(MatrixCoord const &blockShape, MatrixCoord const &blockCoord, MatrixCoord const &actualBlockShape) + { + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmInput; + gmInput.SetGlobalBuffer(params.ptrInput); + AscendC::GlobalTensor gmDequantScale; + gmDequantScale.SetGlobalBuffer(params.ptrDequantScale); + AscendC::GlobalTensor gmOutput; + gmOutput.SetGlobalBuffer(params.ptrOutput); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto ubHalfTileStride = MakeCoord(static_cast(HalfTileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileInput = gmInput[params.layoutInput.GetOffset(tileOffset)]; + auto layoutGmTileInput = params.layoutInput.GetTileLayout(actualTileShape); + + layout::RowMajor layoutUbInput{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(0); + copyGmToUbInput(ubInput, gmTileInput, layoutUbInput, layoutGmTileInput); + AscendC::SetFlag(0); + + AscendC::WaitFlag(0); + AscendC::Abs(ubAbs, ubInput, TileShape::COUNT); + AscendC::PipeBarrier(); + + for (uint32_t rowIdx = 0; rowIdx < HalfTileShape::ROW; ++rowIdx) { + AscendC::Max(ubMax[rowIdx * HalfTileShape::COLUMN], ubAbs[rowIdx * TileShape::COLUMN], + ubAbs[rowIdx * TileShape::COLUMN + HalfTileShape::COLUMN], HalfTileShape::COLUMN); + } + + AscendC::PipeBarrier(); + AscendC::Muls(ubInputTmp, ubInput, 127.f, TileShape::COUNT); + + constexpr uint32_t elementPerBlk = BYTE_PER_BLK / sizeof(float); + constexpr int32_t mask = 64; + + AscendC::BinaryRepeatParams maxParams; + maxParams.dstBlkStride = HalfTileShape::COLUMN / elementPerBlk; + maxParams.src0BlkStride = HalfTileShape::COLUMN / elementPerBlk; + maxParams.src1BlkStride = HalfTileShape::COLUMN / elementPerBlk; + maxParams.dstRepStride = 1; + maxParams.src0RepStride = 1; + maxParams.src1RepStride = 1; + constexpr uint32_t colNumPerCompute = BYTE_PER_VECTOR_FRACTAL / sizeof(float); + uint32_t reduceWidth = HalfTileShape::COLUMN; + while (reduceWidth > (BLK_NUM_PER_VECTOR_FRACTAL * BYTE_PER_BLK / sizeof(float))) { + reduceWidth >>= 1; + AscendC::Max(ubMax, ubMax, ubMax[reduceWidth], mask, reduceWidth / elementPerBlk, maxParams); + AscendC::PipeBarrier(); + } + + AscendC::WholeReduceMax(ubReduceMax, ubMax, mask, HalfTileShape::ROW, 1, 1, + HalfTileShape::COLUMN / elementPerBlk, AscendC::ReduceOrder::ORDER_ONLY_VALUE); + AscendC::SetFlag(0); + AscendC::SetFlag(0); + AscendC::PipeBarrier(); + + AscendC::WaitFlag(0); + AscendC::Muls(ubDequantScale, ubReduceMax, 1.0f / 127.0f, TileShape::ROW); + AscendC::SetFlag(0); + + auto dequantScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto dequantScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); + + auto gmTileDequantScale = gmDequantScale[params.layoutDequantScale.GetOffset(dequantScaleTileOffset)]; + auto layoutGmTileDequantScale = 
params.layoutDequantScale.GetTileLayout(dequantScaleTileShape); + + auto layoutUbDequantScale = + LayoutDequantScale::template MakeLayoutInUb(dequantScaleTileShape); + + AscendC::WaitFlag(0); + copyUbToGmDequantScale(gmTileDequantScale, ubDequantScale, layoutGmTileDequantScale, layoutUbDequantScale); + AscendC::SetFlag(0); + + AscendC::WaitFlag(0); + for (uint32_t rowIdx = 0; rowIdx < TileShape::ROW; ++rowIdx) { + AscendC::Muls(ubQuantF32[rowIdx * TileShape::COLUMN], ubInputTmp[rowIdx * TileShape::COLUMN], + 1.f / ubReduceMax.GetValue(rowIdx), TileShape::COLUMN); + } + + AscendC::PipeBarrier(); + AscendC::Cast(ubQuantS32, ubQuantF32, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::PipeBarrier(); + AscendC::SetDeqScale(static_cast(1.0)); + AscendC::Cast(ubQuantF16, ubQuantS32, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::PipeBarrier(); + + AscendC::WaitFlag(1); + AscendC::Cast(ubOutput, ubQuantF16, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(1); + + auto gmTileOutput = gmOutput[params.layoutOutput.GetOffset(tileOffset)]; + auto layoutGmTileOutput = params.layoutOutput.GetTileLayout(actualTileShape); + + LayoutOutput layoutUbOutput{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(1); + copyUbToGmOutput(gmTileOutput, ubOutput, layoutGmTileOutput, layoutUbOutput); + AscendC::SetFlag(1); + } + } + +private: + Params params; + + AscendC::LocalTensor ubInput; + AscendC::LocalTensor ubDequantScale; + AscendC::LocalTensor ubOutput; + + AscendC::LocalTensor ubAbs; + AscendC::LocalTensor ubMax; + AscendC::LocalTensor ubReduceMax; + AscendC::LocalTensor ubQuantScale; + AscendC::LocalTensor ubQuantScaleBrcb; + AscendC::LocalTensor ubInputTmp; + AscendC::LocalTensor ubQuantF32; + AscendC::LocalTensor ubQuantS32; + AscendC::LocalTensor ubQuantF16; + + Epilogue::Tile::CopyGm2Ub copyGmToUbInput; + Epilogue::Tile::CopyUb2Gm copyUbToGmDequantScale; + Epilogue::Tile::CopyUb2Gm copyUbToGmOutput; +}; + +__aicore__ inline static void EncreaseSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx) +{ + // flag++,类似set flag + AscendC::PipeBarrier(); + AscendC::GlobalTensor global; + global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + global); + __asm__ __volatile__(""); + uint8_t value = global.GetValue(0); + global.SetValue(0, value + 1); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + global); + __asm__ __volatile__(""); + AscendC::PipeBarrier(); +} + +__aicore__ inline static void CheckSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx, uint32_t target) +{ + // 查看flag,类似wait flag + AscendC::PipeBarrier(); + AscendC::GlobalTensor global; + global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE); + while (true) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(global); + __asm__ __volatile__(""); + uint8_t value = global.GetValue(0); + if (value >= target) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(global); + __asm__ __volatile__(""); + break; + } + } + AscendC::PipeBarrier(); +} + +template +class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename 
BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using ElementDequantScale = typename BlockQuant::ElementDequantScale; + using LayoutDequantScale = typename BlockQuant::LayoutDequantScale; + using ElementOutput = typename BlockQuant::ElementOutput; + using LayoutOutput = typename BlockQuant::LayoutOutput; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + using XType = XType_; + + // Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementOutput *ptrOutput; + LayoutOutput layoutOutput; + __gm__ ElementDequantScale *ptrDequantScale; + LayoutDequantScale layoutDequantScale; + GM_ADDR ptrWorkspace; + GM_ADDR gmX; + GM_ADDR debugGm; + GM_ADDR gmexpertIds; + + GM_ADDR gmExpandIdx; + GM_ADDR gmEpSendCount; + GM_ADDR gmResvered; + + uint32_t epRankSize; + uint32_t epRankId; + uint32_t moeExpertNum; + uint32_t moeExpertNumPerRank; + uint32_t sharedExpertNum; + uint32_t sharedExpertRankNum; + uint32_t quantMode; + uint32_t globalBs; + uint32_t bs; + // Methods + CATLASS_DEVICE + Params() {} + + CATLASS_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, + LayoutA const &layoutA_, GM_ADDR ptrB_, LayoutB const &layoutB_, GM_ADDR ptrScale_, + LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_, + GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_, + GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, + GM_ADDR gmResvered_, uint32_t epRankSize_, uint32_t epRankId_, uint32_t moeExpertNum_, + uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_, uint32_t sharedExpertRankNum_, + uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrOutput(reinterpret_cast<__gm__ ElementOutput *>(ptrOutput_)), + layoutOutput(layoutOutput_), + ptrDequantScale(reinterpret_cast<__gm__ ElementDequantScale *>(ptrDequantScale_)), + 
layoutDequantScale(layoutDequantScale_), + ptrWorkspace(ptrWorkspace_), + gmX(gmX_), + debugGm(debugGm_), + gmexpertIds(gmexpertIds_), + gmExpandIdx(gmExpandIdx_), + gmEpSendCount(gmEpSendCount_), + gmResvered(gmResvered_), + epRankSize(epRankSize_), + epRankId(epRankId_), + moeExpertNum(moeExpertNum_), + moeExpertNumPerRank(moeExpertNumPerRank_), + sharedExpertNum(sharedExpertNum_), + sharedExpertRankNum(sharedExpertRankNum_), + quantMode(quantMode_), + globalBs(globalBs_), + bs(bs_) + {} + }; + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + aicIdx = AscendC::GetBlockIdx(); + subBlockNum = AscendC::GetSubBlockNum(); + aiCoreGroupNum = AscendC::GetBlockNum(); + aicNum = aiCoreGroupNum; + aicStateGlobalCoreIdx = AIC_STATE_SPACE_IDNEX + aicIdx; + moeExpertNumPerRank = params.moeExpertNumPerRank; + isShareExpert = (params.epRankId < params.sharedExpertRankNum); + localExpertNum = isShareExpert ? 1 : moeExpertNumPerRank; + // 单卡单专家48发48收 + recvCoreNum = RECV_AIV_CORE_NUM; + // 单卡多专家24收24发 + if (moeExpertNumPerRank > 1) { + recvCoreNum = RECV_AIV_CORE_NUM / SUB_AIV_NUM; + } + uint32_t coreNumPerGroup = recvCoreNum / localExpertNum; // 这里假设可以整除 + winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + + // 更新状态,影响CV交互使用的信号值 + statusDataSpaceGm = (GM_ADDR)(winContext_->localWindowsExp); + AscendC::GlobalTensor selfDataStatusTensor; + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aicStateGlobalCoreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + cvDataState = selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN); + if (cvDataState == 0) { + selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN) = 1; + vToCFlag = V_TO_C_FLAG_1; + } else { + selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN) = 0; + vToCFlag = V_TO_C_FLAG_2; + } + + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * aicNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + AscendC::GlobalTensor groupTokenNumStateTensor; + aicSetFunc1 = {statusDataSpaceGm + SOFT_SYNC_OFFSET, + static_cast(aicNum + AscendC::GetBlockIdx())}; // AIV等待的信息在24~48 + uint32_t target = 1; + for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) { + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) + + groupIdx * GROUP_INFO_SIZE); + // 等待AIV的token收齐信号后,再往下走 + while (true) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(groupTokenNumStateTensor); + __asm__ __volatile__(""); + if (groupTokenNumStateTensor.GetValue(0) == coreNumPerGroup * vToCFlag) { + break; + } + } + + uint32_t currentM = groupTokenNumStateTensor.GetValue(GROUP_TOKEN_COUNT); + GemmCoord inGroupProblemShape{currentM, 
params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((aicIdx < startCoreIdx) ? (aicIdx + aicNum) : aicIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aicNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + // 使用软同步 + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + aicWaitFunc1 = {statusDataSpaceGm + SOFT_SYNC_OFFSET, static_cast(AscendC::GetBlockIdx()), + target}; // AIC等待的信号在前24个 + target += 1; + callbackBeforeFixpipe = MakeCallback(&aicWaitFunc1); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFunc1); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * aicNum + aicIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % aicNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? 
(stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + target += 1; // catch up with the extra soft syncs issued by the AIVs + --stageUsed; + } + } + + CATLASS_DEVICE + void CalExpandxIdx(int32_t dstExpertId, uint32_t tokenIndex, int32_t &curExpertCnt, int64_t ubOffset) + { + // use the AIV to compute the offset at which this token is written on the peer + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor dstExpIdTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::LocalTensor subExpIdTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::LocalTensor workLocalTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::Duplicate(dstExpIdTensor_, dstExpertId, tokenIndex); + AscendC::PipeBarrier(); + AscendC::Sub(subExpIdTensor_, expertIdsTensor_, dstExpIdTensor_, tokenIndex); + AscendC::PipeBarrier(); + AscendC::LocalTensor tmpFp32 = subExpIdTensor_.ReinterpretCast(); + AscendC::LocalTensor tmpoutFp32 = dstExpIdTensor_.ReinterpretCast(); + AscendC::Abs(tmpoutFp32, tmpFp32, tokenIndex); + AscendC::PipeBarrier(); + AscendC::Mins(subExpIdTensor_, dstExpIdTensor_, 1, tokenIndex); + AscendC::PipeBarrier(); + AscendC::ReduceSum(tmpoutFp32, tmpFp32, workLocalTensor_, tokenIndex); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + int32_t curOtherExpertCnt = dstExpIdTensor_(0); + if (tokenIndex > curOtherExpertCnt) { + curExpertCnt = tokenIndex - curOtherExpertCnt; + } + } + + CATLASS_DEVICE + void CalAndSendTokenCount() + { + // compute how many tokens go to each expert and send the counts out + uint32_t totalExpertNum = sharedExpertRankNum + moeExpertNum; + uint32_t sendCountExpertNum = totalExpertNum / sendCoreNum; // number of experts each AIV handles + uint32_t remainderRankNum = totalExpertNum % sendCoreNum; + uint32_t startExpertId = sendCountExpertNum * sendCoreIdx; // sharedExpertRankNum, the starting rank id each AIV sends from + if (sendCoreIdx < remainderRankNum) { // the first remainderRankNum AIVs each send one extra rank's worth of data + sendCountExpertNum += 1; + startExpertId += sendCoreIdx; + } else { + startExpertId += remainderRankNum; + } + uint32_t endExpertId = startExpertId + sendCountExpertNum; + if (startExpertId >= totalExpertNum) { + return; + } + // compute counts and offsets: begin + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(CEIL(expertCntUp, INT32_COUNT_PER_BLOCK) * INT32_COUNT_PER_BLOCK * UB_BLOCK_SIZE); + AscendC::Duplicate(statusTensor_, (int32_t)0, + expertCntUp * INT32_COUNT_PER_BLOCK); // clear to zero before assigning; the clear is mandatory + if (state == 0) { + // operate on 256 bytes (64 int32_t) at a time, setting the first of every 8 elements to 0x3F800000, i.e. float 1.0 + uint64_t mask[2] = {0x101010101010101, 0}; + AscendC::PipeBarrier(); + // the original code has a bug here: when the block count is not a multiple of 8, the trailing part never gets updated + AscendC::Duplicate(statusTensor_, 0x3F800000, mask, CEIL(expertCntUp, 8), 1, 8); + } + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + for (uint32_t curExpertId = startExpertId; curExpertId < endExpertId; ++curExpertId) { + if (curExpertId < sharedExpertRankNum) { + continue; + } + int32_t curExpertCnt = 0; + int32_t dstExpertId = curExpertId - sharedExpertRankNum; + CalExpandxIdx(dstExpertId, expertIdsCnt, curExpertCnt, ubOffset); + int32_t cntPosIndex = curExpertId * 8 + 1; // the 8 means each expert occupies 32 bytes + statusTensor_(cntPosIndex) = curExpertCnt; + } + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + AscendC::GlobalTensor rankGMTensor; + uint32_t offset = stateOffset * epRankId; + for (uint32_t rankIndex = startExpertId; rankIndex < endExpertId; ++rankIndex) { + uint32_t dstRankId = rankIndex; + if (moeExpertNumPerRank > 1 && (rankIndex >= sharedExpertRankNum)) { + dstRankId = ((rankIndex -
sharedExpertRankNum) / moeExpertNumPerRank + sharedExpertRankNum); + offset = + (epRankId + (rankIndex - sharedExpertRankNum) % moeExpertNumPerRank * epRankSize) * stateOffset; + } + GM_ADDR rankGM = (__gm__ uint8_t *)(GET_WIND_STATE_ADDR_BY_RANK_ID(dstRankId) + offset); // compute the address offset + rankGMTensor.SetGlobalBuffer((__gm__ int32_t *)rankGM); + AscendC::DataCopy(rankGMTensor, statusTensor_[rankIndex * 8], 8UL); // 8 is the data size; copy with 32-byte alignment + } + } + + CATLASS_DEVICE + void QuantToken(AscendC::LocalTensor &xInTensor, AscendC::LocalTensor &yInt8Tensor, int64_t ubOffset) + { + // token quantization; the UB space used here is freed right after use, so UB offsets are computed locally + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor xFp32TmpTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(tokenLength * sizeof(float)); + AscendC::LocalTensor xFp32AbsTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(tokenLength * sizeof(float)); + AscendC::LocalTensor xRowMaxTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor ytmpInt32Tensor = xFp32TmpTensor.template ReinterpretCast(); + AscendC::LocalTensor yHalfTensor = xFp32TmpTensor.template ReinterpretCast(); + AscendC::LocalTensor yFp32Tensor = yInt8Tensor.template ReinterpretCast(); + AscendC::LocalTensor yInt32Tensor = yInt8Tensor.template ReinterpretCast(); + + AscendC::Cast(xFp32TmpTensor, xInTensor, AscendC::RoundMode::CAST_NONE, tokenLength); + AscendC::PipeBarrier(); + AscendC::Abs(xFp32AbsTensor, xFp32TmpTensor, tokenLength); + AscendC::PipeBarrier(); + AscendC::ReduceMax(xRowMaxTensor, xFp32AbsTensor, xFp32AbsTensor, tokenLength, false); + AscendC::PipeBarrier(); + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + float dynamicQuantScale = float(127.0) / xRowMaxTensor.GetValue(0); + yFp32Tensor.SetValue(tokenLength / sizeof(float), float(1.0) / dynamicQuantScale); + yInt32Tensor.SetValue(tokenLength / sizeof(int32_t) + 1, tokenFlag); + AscendC::SetFlag(0); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + AscendC::Muls(xFp32TmpTensor, xFp32TmpTensor, dynamicQuantScale, tokenLength); + AscendC::PipeBarrier(); + AscendC::Cast(ytmpInt32Tensor, xFp32TmpTensor, AscendC::RoundMode::CAST_RINT, tokenLength); + AscendC::PipeBarrier(); + AscendC::Cast(yHalfTensor, ytmpInt32Tensor, AscendC::RoundMode::CAST_ROUND, tokenLength); + AscendC::PipeBarrier(); + AscendC::Cast(yInt8Tensor, yHalfTensor, AscendC::RoundMode::CAST_TRUNC, tokenLength); + } + + CATLASS_DEVICE + void SendToShareExprt(GM_ADDR gmX, GM_ADDR gmX1, GM_ADDR gmX1Scale) + { + // send tokens to the shared experts + uint32_t newAivId = sendCoreIdx - sendToMoeAivNum; + uint32_t sendTokenNum = axisBS / sendToShareAivNum; // number of tokens each AIV sends + uint32_t remainderTokenNum = expertIdsCnt % sendToShareAivNum; // remainder + uint32_t startTokenId = sendTokenNum * newAivId; // starting token index for this AIV's sends + if (newAivId < remainderTokenNum) { // the first few AIVs (the remainder) each send one extra token + sendTokenNum += 1; + startTokenId += newAivId; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; +#ifdef SEND_TOKEN_RETURN + if (startTokenId >= axisBS) { + return; + } +#endif + AscendC::LocalTensor xInTensor[BUFFER_NUM]; + AscendC::LocalTensor yInt8Tensor[BUFFER_NUM]; + AscendC::LocalTensor yFp32Tensor[BUFFER_NUM]; + + AscendC::GlobalTensor srcWinGMTensor; // token input + srcWinGMTensor.SetGlobalBuffer((__gm__ XType *)gmX); + + xInTensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tokenLength * sizeof(XType)); +
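// The xInTensor/yInt8Tensor pairs here form a ping-pong buffer (BUFFER_NUM == 2): tokens with
// even indices use buffer/event id 1 and odd indices use id 0, so the GM->UB copy and
// quantization of one token can overlap with the UB->GM send of the previous one; the
// MTE3->MTE2 and V->MTE2 flags set before the loop guard this buffer reuse.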
xInTensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tokenLength * sizeof(XType)); + yInt8Tensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + yInt8Tensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + AscendC::GlobalTensor dstWinGMTensor; // token输出 + AscendC::GlobalTensor expandXOutGlobal; + expandXOutGlobal.SetGlobalBuffer((__gm__ int8_t *)(gmX1)); + AscendC::GlobalTensor dynamicScalesOutGMTensor_; + dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)(gmX1Scale)); +#ifndef SEND_TOKEN_RETURN + if (startTokenId < axisBS) { +#endif + // 输入输出开double buffer + AscendC::SetFlag(0); // MTE2等MTE3 + AscendC::SetFlag(1); // MTE2等MTE3 + AscendC::SetFlag(0); + AscendC::SetFlag(1); + + for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) { + uint32_t index = (tokenIndex & 1) ? 0 : 1; + int32_t eventId = (tokenIndex & 1) ? 0 : 1; + // 下面的计算有点绕,目的是计算目的专家卡和偏移 + uint32_t temp = (epRankId * axisBS) / sharedExpertRankNum; + // 当前token发给哪个共享专家 + uint32_t moeOnShareRank = CEIL((tokenIndex + 1 + temp) * sharedExpertRankNum, axisBS) - 1 - epRankId; + // 发给该共享专家已经有多少token数据 + uint32_t preCnt = (moeOnShareRank + epRankId) * axisBS / sharedExpertRankNum - + epRankId * axisBS / sharedExpertRankNum; + dstWinGMTensor.SetGlobalBuffer( + (__gm__ int8_t *)(GET_WIND_ADDR_BY_RANK_ID(moeOnShareRank) + expertPerSizeOnWin * epRankId)); + + AscendC::WaitFlag(eventId); + AscendC::WaitFlag(eventId); + AscendC::DataCopy(xInTensor[index], srcWinGMTensor[tokenIndex * tokenLength], tokenLength); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + QuantToken(xInTensor[index], yInt8Tensor[index], ubOffset); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(0); + + AscendC::WaitFlag(eventId); + if (isShareExpert) { + AscendC::DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + AscendC::DataCopy(expandXOutGlobal[tokenIndex * tokenLength], yInt8Tensor[index], tokenLength); + AscendC::PipeBarrier(); + AscendC::DataCopyPad(dynamicScalesOutGMTensor_[tokenIndex], + yFp32Tensor[index][tokenLength / sizeof(float)], dataCopyParamsFloat); + } else { + // 怀疑有时序问题,所以分开发送 + AscendC::DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommu], yInt8Tensor[index], + tokenLength); + AscendC::PipeBarrier(); + AscendC::DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommu + tokenLength], + yInt8Tensor[index][tokenLength], scaleParamPad); + } + AscendC::SetFlag(eventId); + AscendC::SetFlag(eventId); + } + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); +#ifndef SEND_TOKEN_RETURN + } +#endif + } + + CATLASS_DEVICE + void SendToMoeExprt(GM_ADDR gmX, GM_ADDR gmExpandIdx) + { + // 给路由专家发送token + uint32_t sendTokenNum = expertIdsCnt / sendToMoeAivNum; + uint32_t remainderTokenNum = expertIdsCnt % sendToMoeAivNum; + uint32_t startTokenId = sendTokenNum * sendCoreIdx; + if (sendCoreIdx < remainderTokenNum) { + sendTokenNum += 1; + startTokenId += sendCoreIdx; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; +#ifdef SEND_TOKEN_RETURN + if (startTokenId >= expertIdsCnt) { + return; + } +#else + if (startTokenId < expertIdsCnt) { +#endif + AscendC::LocalTensor expertCountTensor = (resource.ubBuf.template GetBufferByByte(ubOffset)); + ubOffset += CEIL_UP(expertIdsCnt * sizeof(int32_t)); + AscendC::Duplicate(expertCountTensor, 
(int32_t)0, expertIdsCnt); // 清零 + AscendC::SetFlag(1); + AscendC::WaitFlag(1); + + AscendC::LocalTensor xInTensor[BUFFER_NUM]; + AscendC::LocalTensor yInt8Tensor[BUFFER_NUM]; + AscendC::LocalTensor yFp32Tensor[BUFFER_NUM]; + + AscendC::GlobalTensor srcWinGMTensor; // token输入 + srcWinGMTensor.SetGlobalBuffer((__gm__ XType *)gmX); + + xInTensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tokenLength * sizeof(XType)); + xInTensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tokenLength * sizeof(XType)); + yInt8Tensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + yInt8Tensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + AscendC::GlobalTensor dstWinGMTensor; // token输出 + // 输入输出开double buffer + AscendC::SetFlag(0); // MTE2等MTE3 + AscendC::SetFlag(1); // MTE2等MTE3 + AscendC::SetFlag(0); + AscendC::SetFlag(1); + uint32_t sendValidTokenIndex = 0; + for (uint32_t sendGroupIndex = 0; sendGroupIndex < moeExpertNumPerRank; ++sendGroupIndex) { + for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) { + uint32_t dstExpertId = expertIdsTensor_(tokenIndex); + if ((dstExpertId % moeExpertNumPerRank) != sendGroupIndex) { // 优先发送指定专家的token + continue; + } + uint32_t index = (sendValidTokenIndex & 1) ? 0 : 1; + int32_t eventId = (sendValidTokenIndex & 1) ? 0 : 1; + sendValidTokenIndex += 1; + int32_t curExpertCnt = 0; + CalExpandxIdx(dstExpertId, tokenIndex, curExpertCnt, ubOffset); + expertCountTensor(tokenIndex - startTokenId) = curExpertCnt; + uint32_t tempRankId = dstExpertId / moeExpertNumPerRank + sharedExpertRankNum; + GM_ADDR rankGM = (__gm__ uint8_t *)(GET_WIND_ADDR_BY_RANK_ID(tempRankId) + + (expertPerSizeOnWin * (epRankId * moeExpertNumPerRank + + dstExpertId % moeExpertNumPerRank)) + + hCommuSize * curExpertCnt); + dstWinGMTensor.SetGlobalBuffer((__gm__ int8_t *)rankGM); + + AscendC::WaitFlag(eventId); + AscendC::WaitFlag(eventId); + AscendC::DataCopy(xInTensor[index], srcWinGMTensor[tokenIndex / axisK_ * tokenLength], tokenLength); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + QuantToken(xInTensor[index], yInt8Tensor[index], ubOffset); + AscendC::SetFlag(eventId); + + AscendC::WaitFlag(0); + AscendC::WaitFlag(eventId); + + // 担心有时序问题,所以分开发送 + AscendC::DataCopy(dstWinGMTensor, yInt8Tensor[index], tokenLength); + AscendC::PipeBarrier(); + AscendC::DataCopy(dstWinGMTensor[tokenLength], yInt8Tensor[index][tokenLength], scaleParamPad); + AscendC::SetFlag(eventId); + AscendC::SetFlag(eventId); + } + } + AscendC::WaitFlag(0); // MTE2等MTE3 + AscendC::WaitFlag(1); // MTE2等MTE3 + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + + AscendC::GlobalTensor expandIdxGMTensor; + expandIdxGMTensor.SetGlobalBuffer((__gm__ int32_t *)gmExpandIdx + startTokenId); + AscendC::DataCopyExtParams expertIdsCntParams = {1U, static_cast(sendTokenNum * sizeof(uint32_t)), 0U, + 0U, 0U}; + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopyPad(expandIdxGMTensor, expertCountTensor, expertIdsCntParams); +#ifndef SEND_TOKEN_RETURN + } +#endif +} + +CATLASS_DEVICE void +SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx) +{ + ubOffset = 0; + expertIdsCnt = axisBS * axisK_; + + AscendC::GlobalTensor expertIdsGMTensor_; + expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)gmExpertIds); + expertIdsTensor_ = (resource.ubBuf.template 
GetBufferByByte(ubOffset)); + ubOffset += CEIL_UP(expertIdsCnt * sizeof(int32_t)); + + AscendC::DataCopyExtParams expertIdsCntParams = {1U, static_cast(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, + 0U}; + AscendC::DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + CalAndSendTokenCount(); + AscendC::PipeBarrier(); + if (hasShareExpert) { + sendToShareAivNum = sendCoreNum / (axisK_ + 1); // 均等分,取整 + if (sendToShareAivNum == 0) { + sendToShareAivNum = 1; + } + } + sendToMoeAivNum = sendCoreNum - sendToShareAivNum; + + AscendC::SetDeqScale((half)1.000000e+00f); + if (hasShareExpert && sendCoreIdx >= sendToMoeAivNum) { + SendToShareExprt(gmX, gmX1, gmX1Scale); + } else { + SendToMoeExprt(gmX, gmExpandIdx); + } + AscendC::PipeBarrier(); +} + +CATLASS_DEVICE +void RecvCount(int64_t ubOffset) +{ + // 接收count数据 + uint32_t recStatusNumPerCore = isShareExpert ? epRankSize : expertCntUp; + uint32_t startStatusIndex = 0; // 目前每个核都要收集所有的count + + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE); + AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * sizeof(float)); + AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + + AscendC::LocalTensor statusSumOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor sumTmpTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(SUM_TMP_TENSOR_SIZE); + gatherTmpTensor.SetValue(0, 1); + + uint32_t mask = 1; // gatherMask + sum 相关参数 + uint64_t rsvdCnt = 0; + AscendC::SumParams sumParams{1, recStatusNumPerCore, recStatusNumPerCore}; + float sumOfFlag = static_cast(-1.0); + float minTarget = (sumTarget * recStatusNumPerCore) - (float)0.5; + float maxTarget = (sumTarget * recStatusNumPerCore) + (float)0.5; + AscendC::DataCopyParams intriParams{static_cast(recStatusNumPerCore), 1, static_cast(15), + 0}; // srcStride为15个block + AscendC::GlobalTensor windowInstatusFp32Tensor_; + windowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)GET_WIND_STATE_ADDR_BY_RANK_ID(epRankId)); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + uint32_t preRecvTokenCount = 0; + while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) { + AscendC::DataCopy(statusFp32Tensor_, windowInstatusFp32Tensor_[startStatusIndex * stateOffset / sizeof(float)], + intriParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, mask, + {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt); + AscendC::PipeBarrier(); + AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumTmpTensor, sumParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + sumOfFlag = statusSumOutTensor.GetValue(0); + } +} + +CATLASS_DEVICE +void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset) +{ + // 计算前缀和,目的是知道自己收到的token在output中的偏移 + int64_t subUbOffset = ubOffset; + uint32_t recStatusNumPerCore = isShareExpert ? 
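When a shared expert is configured, SendCoreFunc splits the send cores roughly in proportion to the copies each token produces: axisK_ routed copies plus one shared copy, so about 1/(axisK_ + 1) of the cores (at least one) serve the shared expert and the rest serve the routed experts. A minimal sketch of that split, with illustrative names.

#include <cstdint>

struct SendCoreSplit {
    uint32_t toShare;  // cores running SendToShareExprt
    uint32_t toMoe;    // cores running SendToMoeExprt
};

static SendCoreSplit split_send_cores(uint32_t sendCoreNum, uint32_t axisK, bool hasShareExpert)
{
    SendCoreSplit s{0, sendCoreNum};
    if (hasShareExpert) {
        s.toShare = sendCoreNum / (axisK + 1);  // even split, rounded down
        if (s.toShare == 0) {
            s.toShare = 1;  // always keep at least one core for the shared expert
        }
        s.toMoe = sendCoreNum - s.toShare;
    }
    return s;
}
// A core whose sendCoreIdx >= toMoe (and hasShareExpert set) takes the shared-expert path.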
epRankSize : expertCntUp; + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE); + AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * sizeof(float)); + AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + if (isShareExpert) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId) * axisBS / sharedExpertRankNum - + (curSatatusExpId + epRankId) * axisBS / sharedExpertRankNum; + statusTensor_((curSatatusExpId)*INT32_COUNT_PER_BLOCK + 1) = curExpertCnt; + } + } + + uint64_t rsvdCnt = 0; + gatherTmpTensor.SetValue(0, GATHER_SECOND_NUM); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, GATHER_SECOND_NUM, + {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt); + + // 这里是为ReduceSum准备所需空间,本应该计算好需要多大空间,但当前是给偏移,且用完就释放,所以就不计算了 + AscendC::LocalTensor workLocalTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + AscendC::PipeBarrier(); + AscendC::ReduceSum(gatherMaskOutTensor, gatherMaskOutTensor, workLocalTensor, + (startRankId + 1) <= recvExpertNum ? (startRankId + 1) : recvExpertNum); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); +} + +CATLASS_DEVICE +void RecvToken(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount, uint32_t &coreTokenCount, uint32_t startRankId, + uint32_t endRankId, uint32_t recvRankNumPerCore, int64_t ubOffset) +{ + // 接收token + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE); + AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * sizeof(float)); + AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + + AscendC::DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + AscendC::LocalTensor xTmpTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + AscendC::LocalTensor xOutFp32Tensor_ = xTmpTensor_.template ReinterpretCast(); + AscendC::LocalTensor tmpLocalTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutCountTensor = (gatherMaskOutTensor.template ReinterpretCast()); + AscendC::GlobalTensor tokGlobal; + AscendC::GlobalTensor tokGlobalInt32; + AscendC::GlobalTensor expandXOutGlobal; + AscendC::GlobalTensor dynamicScalesOutGMTensor_; + dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)(gmX1Scale)); + uint32_t beginIdx = 0; + for (uint32_t index = startRankId; index < endRankId; index++) { + uint32_t i = index - startRankId; + if (i > 0) { + gatherMaskOutCountTensor.SetValue( + i, gatherMaskOutCountTensor.GetValue(i - 1) + gatherMaskOutCountTensor.GetValue(index)); + } + uint32_t count = statusTensor_.GetValue(index * INT32_COUNT_PER_BLOCK + 1); + coreTokenCount += count; + beginIdx = 
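Conceptually, GetCumSum gives a receiving core its base write offset in the packed output: the inclusive prefix sum of per-slot token counts up to the core's startRankId, from which RecvToken then subtracts the current slot's own count to get the exclusive beginIdx. A host-side sketch of that idea only; it deliberately ignores the on-chip GatherMask/ReduceSum mechanics, and the counts vector is illustrative.

#include <cstdint>
#include <vector>

// counts[i] = number of tokens received for slot i (a source rank / local expert pair).
static uint32_t inclusive_prefix_sum(const std::vector<uint32_t> &counts, uint32_t startRankId)
{
    uint32_t sum = 0;
    for (uint32_t i = 0; i <= startRankId && i < counts.size(); ++i) {
        sum += counts[i];
    }
    // RecvToken uses (inclusive sum for a slot) - (that slot's own count) as beginIdx.
    return sum;
}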
gatherMaskOutCountTensor.GetValue(i) - count; + if (isShareExpert && index < sharedExpertRankNum) { + beginIdx += count; + continue; + } + uint32_t winOffset = index; + if (!isShareExpert && moeExpertNumPerRank > 1) { + // count的空间排布,与token数据的空间排布不同,需要转换成数据区的排布偏移 + // srcRank: index % epRankSize + // localExpertId: index / epRankSize + // Addr: (srcRank * moeExpertNumPerRank + localExpertId) * expertPerSizeOnWin + winOffset = (index % epRankSize) * moeExpertNumPerRank + index / epRankSize; + } + GM_ADDR wAddr = (__gm__ uint8_t *)(GET_WIND_ADDR_BY_RANK_ID(epRankId)) + winOffset * expertPerSizeOnWin; + AscendC::SetFlag(0); + for (uint32_t j = 0; j < count; j++) { + tokGlobal.SetGlobalBuffer((__gm__ int8_t *)(wAddr + j * hCommuSize)); + tokGlobalInt32.SetGlobalBuffer((__gm__ int32_t *)(wAddr + j * hCommuSize + hOutSize)); + expandXOutGlobal.SetGlobalBuffer((__gm__ int8_t *)(gmX1) + (beginIdx + j) * tokenLength, tokenLength); + + while (true) { + AscendC::DataCopy(tmpLocalTensor, tokGlobalInt32, INT32_COUNT_PER_BLOCK); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + if (tmpLocalTensor.GetValue(1) == tokenFlag) { + break; + } + } + AscendC::PipeBarrier(); + + AscendC::WaitFlag(0); + AscendC::DataCopy(xTmpTensor_, tokGlobal, axisHCommu); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopyPad(dynamicScalesOutGMTensor_[beginIdx + j], xOutFp32Tensor_[tokenLength / sizeof(float)], + dataCopyParamsFloat); + AscendC::DataCopy(expandXOutGlobal, xTmpTensor_, tokenLength); + AscendC::SetFlag(0); + } + AscendC::WaitFlag(0); + beginIdx += count; + } + AscendC::PipeBarrier(); + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopyExtParams dataCopyOutParams = {1U, static_cast(recvRankNumPerCore * sizeof(int32_t)), 0U, + 0U, 0U}; + AscendC::GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount)); + AscendC::DataCopyPad(sendCountsGlobal[startRankId], gatherMaskOutCountTensor, dataCopyOutParams); +} + +CATLASS_DEVICE +void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount) +{ + ubOffset = 0; + RecvCount(ubOffset); + + // 先按本地专家分核,再在专家内进一步分核 + uint32_t recvExpertNum = isShareExpert ? 
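The winOffset conversion in RecvToken exists because the count area and the token data area use different layouts: counts are indexed as localExpertId * epRankSize + srcRank, while data slots in the window are laid out as srcRank * moeExpertNumPerRank + localExpertId. A tiny host-side model of the remap (only needed when moeExpertNumPerRank > 1); names are illustrative.

#include <cstdint>

static uint32_t count_index_to_data_slot(uint32_t index, uint32_t epRankSize,
                                          uint32_t moeExpertNumPerRank)
{
    uint32_t srcRank = index % epRankSize;        // who sent the tokens
    uint32_t localExpertId = index / epRankSize;  // which local expert they target
    return srcRank * moeExpertNumPerRank + localExpertId;
}
// Example: epRankSize = 4, moeExpertNumPerRank = 2; count index 5 (expert 1, src rank 1)
// maps to data slot 1 * 2 + 1 = 3.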
epRankSize : expertCntUp; + uint32_t recvCoreNumPerGroup = recvCoreNum / localExpertNum; // 每个group由若干核处理,这里先假定可以整除且不为0 + uint32_t recvRankNumPerCore = epRankSize / recvCoreNumPerGroup; // 每个核处理的rank数量 + uint32_t remainderRankNum = epRankSize % recvCoreNumPerGroup; + + uint32_t groupId = recvCoreIdx / recvCoreNumPerGroup; // 当前核处理的是哪个group + uint32_t recvCoreIdxInGroup = recvCoreIdx % recvCoreNumPerGroup; // 当前核处理的是group中第几个 + uint32_t startRankIdInGroup = recvRankNumPerCore * recvCoreIdxInGroup; // 当前核处理的起始rank + if (recvCoreIdxInGroup < remainderRankNum) { + recvRankNumPerCore += 1; + startRankIdInGroup += recvCoreIdxInGroup; + } else { + startRankIdInGroup += remainderRankNum; + } + uint32_t endRankIdInGroup = startRankIdInGroup + recvRankNumPerCore; + uint32_t startRankId = epRankSize * groupId + startRankIdInGroup; + uint32_t endRankId = epRankSize * groupId + endRankIdInGroup; + + uint32_t coreTokenCount = 0; + + if (startRankId < recvExpertNum) { + // 计算前缀和,以及接收token。这里有隐含约束,下面两个函数与RecvCount的ubOffset入参应保持一致,这样才能拿到有效数据 + GetCumSum(startRankId, recvExpertNum, ubOffset); + RecvToken(gmX1, gmX1Scale, gmEpSendCount, coreTokenCount, startRankId, endRankId, recvRankNumPerCore, ubOffset); + } + + // 接收完成,通过写GM告知C核和计算V核 + AscendC::PipeBarrier(); + AscendC::LocalTensor tmpLocalTensor = resource.ubBuf.template GetBufferByByte(0); + ubOffset += CEIL_UP(UB_BLOCK_SIZE); + tmpLocalTensor.SetValue(CV_FLAG_INDEX, vToCFlag); + tmpLocalTensor.SetValue(GROUP_ID_INDEX, groupId); + tmpLocalTensor.SetValue(SELF_COUNT_INDEX, coreTokenCount); + AscendC::SetFlag(0); + + AscendC::GlobalTensor groupTokenNumStateTensor; + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET)); + AscendC::WaitFlag(0); + AscendC::SetAtomicAdd(); + // 用原子加,各个核收到的token数量加一起,就是专家收到的token数量 + AscendC::DataCopy(groupTokenNumStateTensor[groupId * GROUP_INFO_SIZE], tmpLocalTensor, INT32_COUNT_PER_BLOCK); + AscendC::SetAtomicNone(); + AscendC::PipeBarrier(); +} + +CATLASS_DEVICE +void CompCoreFunc(GM_ADDR gmCVSwapBuff, __gm__ ElementScale *gmScale, __gm__ ElementPerTokenScale *gmTokenScale, + __gm__ float *gmSwigluOutput, uint32_t n, uint32_t k, LayoutScale layoutScale, + LayoutPerTokenScale wholeLayoutPerTokenScale, LayoutOutput layoutOutput) +{ + uint32_t nOut = n / 2; + uint32_t coreNumPerGroup = recvCoreNum / localExpertNum; // 这里假设可以整除 + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(gmCVSwapBuff)); + auto layoutC = layout::RowMajor{L1TileShape::M * aiCoreGroupNum * WORKSPACE_STAGES, L1TileShape::N}; + { + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource); + + uint32_t stageId = 0; + uint32_t target = 1; + uint32_t startCoreIdx = 0; + + AscendC::GlobalTensor groupTokenNumStateTensor; + for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) { + // 流程与C核类似,等专家token数据,以及计算、软同步 + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) + + groupIdx * GROUP_INFO_SIZE); + while (true) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(groupTokenNumStateTensor); + __asm__ __volatile__(""); + if (groupTokenNumStateTensor.GetValue(0) == coreNumPerGroup * vToCFlag) { + break; + } + } + uint32_t currentM = groupTokenNumStateTensor.GetValue(GROUP_TOKEN_COUNT); + GemmCoord inGroupProblemShape{currentM, n, k}; + LayoutPerTokenScale layoutPerTokenScale = + 
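RecvCoreFunc partitions the receive work on two levels: cores are first divided evenly among the local experts (groups), and inside a group each core takes a contiguous range of source-rank slots, with the remainder spread one extra slot per leading core. A host-side model of that split, assuming (as the kernel comment does) that recvCoreNum divides evenly by localExpertNum; names are illustrative.

#include <cstdint>

struct RecvRange {
    uint32_t groupId;      // which local expert this core serves
    uint32_t startRankId;  // first group-global slot handled by this core
    uint32_t endRankId;    // one past the last slot
};

static RecvRange recv_range(uint32_t recvCoreIdx, uint32_t recvCoreNum,
                            uint32_t localExpertNum, uint32_t epRankSize)
{
    uint32_t coresPerGroup = recvCoreNum / localExpertNum;
    uint32_t ranksPerCore = epRankSize / coresPerGroup;
    uint32_t remainder = epRankSize % coresPerGroup;

    RecvRange r;
    r.groupId = recvCoreIdx / coresPerGroup;
    uint32_t idxInGroup = recvCoreIdx % coresPerGroup;
    uint32_t startInGroup = ranksPerCore * idxInGroup;
    if (idxInGroup < remainder) {
        ranksPerCore += 1;
        startInGroup += idxInGroup;
    } else {
        startInGroup += remainder;
    }
    r.startRankId = epRankSize * r.groupId + startInGroup;
    r.endRankId = r.startRankId + ranksPerCore;
    return r;
}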
wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = layoutOutput.GetTileLayout(MakeCoord(currentM, nOut)); + + EpilogueParams epilogueParams{gmScale + gmGroupOffsetScale, + layoutScale, + gmTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + gmSwigluOutput + gmGroupOffsetD, + layoutD}; + + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = + ((compCoreIdx < startCoreIdx) ? (compCoreIdx + aiCoreGroupNum) : compCoreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aiCoreGroupNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * aiCoreGroupNum + aiCoreGroupIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + CheckSyncFlag(statusDataSpaceGm + SOFT_SYNC_OFFSET, + static_cast(COMP_AIV_CORE_NUM + compCoreIdx), target); // AIV等待的信号在24~48 + target += 1; + blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC); + EncreaseSyncFlag(statusDataSpaceGm + SOFT_SYNC_OFFSET, static_cast(compCoreIdx)); + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetScale += inGroupProblemShape.n(); + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += currentM * nOut; + + startCoreIdx = (startCoreIdx + coreLoops) % aiCoreGroupNum; + } + } + // 清理软同步残留信息,避免影响别处或者下次运行 + AscendC::PipeBarrier(); + AscendC::GlobalTensor softSyncTensor; + softSyncTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + SOFT_SYNC_OFFSET)); + AscendC::LocalTensor tmpZeroLocalTensor = resource.ubBuf.template GetBufferByByte(0); + AscendC::Duplicate(tmpZeroLocalTensor, (int32_t)0, INT32_COUNT_PER_BLOCK); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopy(softSyncTensor[compCoreIdx * SOFT_SYNC_SPACE_SIZE / sizeof(int32_t)], tmpZeroLocalTensor, + INT32_COUNT_PER_BLOCK); + AscendC::DataCopy(softSyncTensor[(compCoreIdx + COMP_AIV_CORE_NUM) * SOFT_SYNC_SPACE_SIZE / sizeof(int32_t)], + tmpZeroLocalTensor, INT32_COUNT_PER_BLOCK); +} + +CATLASS_DEVICE +void AivInitParams(Params const ¶ms) +{ + aiCoreGroupNum = AscendC::GetBlockNum(); + subBlockNum = AscendC::GetSubBlockNum(); + aivIdx = AscendC::GetBlockIdx(); + aiCoreGroupIdx = aivIdx / subBlockNum; + aivStateGlobalCoreIdx = AIV_STATE_SPACE_IDNEX + aivIdx; + + isCompCore = (aivIdx % SUB_AIV_NUM) == 0; // 偶数核做计算 + compCoreNum = COMP_AIV_CORE_NUM; + compCoreIdx = aiCoreGroupIdx; + // 单卡单专家48发48收 + isRecvCore = true; + isSendCore = true; + recvCoreIdx = aivIdx; + sendCoreIdx = aivIdx; + sendCoreNum = SEND_AIV_CORE_NUM; + recvCoreNum = RECV_AIV_CORE_NUM; + + moeExpertNumPerRank = params.moeExpertNumPerRank; + // 单卡多专家改为24收24发 + if (moeExpertNumPerRank > 1) { + isRecvCore = ((aivIdx % ODD_EVEN_BASE) == 0); // 偶数核发送 + isSendCore = ((aivIdx % ODD_EVEN_BASE) == 1); // 基数核接收 + recvCoreIdx = aivIdx / SUB_AIV_NUM; + sendCoreIdx = aivIdx / SUB_AIV_NUM; + sendCoreNum = SEND_AIV_CORE_NUM / SUB_AIV_NUM; + recvCoreNum = RECV_AIV_CORE_NUM / SUB_AIV_NUM; + } + + epRankSize = params.epRankSize; + epRankId = params.epRankId; + 
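The C workspace used by CompCoreFunc is a ring of WORKSPACE_STAGES slots per core group, each L1TileShape::M rows tall, so the matmul producer can run ahead of the epilogue consumer by up to WORKSPACE_STAGES tiles per group. A host-side sketch of the row indexing used for offsetC and of the stage advance; names mirror the kernel but the helpers are illustrative.

#include <cstdint>

// Row layout: [stage 0: group 0..G-1][stage 1: group 0..G-1]...; each block is tileM rows.
static uint64_t workspace_row(uint32_t stageId, uint32_t coreGroupIdx, uint32_t coreGroupNum,
                              uint32_t tileM /* L1TileShape::M */)
{
    return static_cast<uint64_t>(stageId * coreGroupNum + coreGroupIdx) * tileM;
}

static uint32_t next_stage(uint32_t stageId, uint32_t stages /* WORKSPACE_STAGES */)
{
    return (stageId + 1 < stages) ? (stageId + 1) : 0;  // wrap around the ring
}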
expertCntUp = epRankSize * moeExpertNumPerRank; + sharedExpertRankNum = params.sharedExpertRankNum; + hasShareExpert = (sharedExpertRankNum > 0); + isShareExpert = (epRankId < sharedExpertRankNum); + localExpertNum = isShareExpert ? 1 : moeExpertNumPerRank; + moeExpertNum = params.moeExpertNum; + + hOutSize = tokenLength * sizeof(int8_t); + scaleParamPad = TOKEN_EXTRA_SPACE; // 预留512B给量化参数,实际只使用了4B(fp32) + hCommuSize = hOutSize + scaleParamPad; + axisHCommu = hCommuSize / sizeof(int8_t); + axisBS = params.bs; + + stateOffset = STATE_OFFSET; + expertPerSizeOnWin = params.bs * tokenLength * sizeof(XType); + winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + statusDataSpaceGm = (GM_ADDR)(winContext_->localWindowsExp); +} + +CATLASS_DEVICE +void AivInitState() +{ + // 核状态更新,决定使用哪一半空间,以及各种信号的切换 + AscendC::GlobalTensor selfDataStatusTensor; + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aivIdx * UB_ALIGN]); + __asm__ __volatile__(""); + dataState = selfDataStatusTensor(aivIdx * UB_ALIGN); + if (dataState == 0) { + selfDataStatusTensor(aivIdx * UB_ALIGN) = 1; + } else { + selfDataStatusTensor(aivIdx * UB_ALIGN) = 0; + } + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aivIdx * UB_ALIGN]); + __asm__ __volatile__(""); + + // 专家token数据信号 + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aivStateGlobalCoreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + cvDataState = selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN); + if (cvDataState == 0) { + selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN) = 1; + vToCFlag = V_TO_C_FLAG_1; + } else { + selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN) = 0; + vToCFlag = V_TO_C_FLAG_2; + } + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aivStateGlobalCoreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + + AscendC::PipeBarrier(); + winDataSizeOffset = dataState * epRankSize * expertPerSizeOnWin * moeExpertNumPerRank; + GM_ADDR statusSpaceGm_ = GET_WIND_STATE_ADDR_BY_RANK_ID(epRankId); + AscendC::GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusSpaceGm_ + SELF_STATE_OFFSET)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfStatusTensor[aivIdx * UB_ALIGN]); + __asm__ __volatile__(""); + state = selfStatusTensor(aivIdx * UB_ALIGN); + if (state == 0) { + sumTarget = (float)1.0; + tokenFlag = TOKEN_FLAG_1; + selfStatusTensor(aivIdx * UB_ALIGN) = 0x3F800000; // 浮点数的1.0 + } else { + sumTarget = 0.0; + tokenFlag = TOKEN_FLAG_2; + selfStatusTensor(aivIdx * UB_ALIGN) = 0; + } + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfStatusTensor[aivIdx * UB_ALIGN]); + __asm__ __volatile__(""); +} + +CATLASS_DEVICE +void UpdateAndCleanInfo(__gm__ ElementGroupList_ *ptrGroupList, GM_ADDR gmEpSendCount) +{ + if (aivIdx == aiCoreGroupNum * subBlockNum - 1) { + // 清理专家token数量信息 + AscendC::GlobalTensor groupTokenNumStateTensor; + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET)); + AscendC::LocalTensor tmpZeroLocalTensor = resource.ubBuf.template GetBufferByByte(0); + AscendC::Duplicate(tmpZeroLocalTensor, (int32_t)0, GROUP_INFO_SIZE * localExpertNum); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopy(groupTokenNumStateTensor, tmpZeroLocalTensor, 
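The per-token slot in the communication window is hCommuSize = hOutSize + TOKEN_EXTRA_SPACE bytes: tokenLength bytes of int8 payload, then a 512-byte pad whose first 4 bytes carry the fp32 dynamic-quantization scale (the source comment notes only 4 B of the pad are actually used); the receive path also polls an int32 word just after the scale as the token-arrival flag. A host-side sketch of that layout under those assumptions; also note that 0x3F800000 written in AivInitState is simply the IEEE-754 bit pattern of float 1.0.

#include <cstddef>
#include <cstdint>

constexpr size_t kTokenExtraSpace = 512;  // matches scaleParamPad = TOKEN_EXTRA_SPACE

struct TokenSlotLayout {
    size_t dataBytes;    // quantized token payload (int8), hOutSize
    size_t scaleOffset;  // fp32 dynamic scale, first 4 bytes of the pad
    size_t flagOffset;   // int32 arrival flag polled by RecvToken
    size_t slotBytes;    // full slot, hCommuSize
};

static TokenSlotLayout make_slot_layout(size_t tokenLength)
{
    TokenSlotLayout l;
    l.dataBytes = tokenLength * sizeof(int8_t);
    l.scaleOffset = l.dataBytes;
    l.flagOffset = l.dataBytes + sizeof(float);
    l.slotBytes = l.dataBytes + kTokenExtraSpace;
    return l;
}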
GROUP_INFO_SIZE * localExpertNum); + } + + if (isRecvCore && recvCoreIdx == (recvCoreNum - 1)) { + // 更新group_list信息 + AscendC::GlobalTensor expertTokenNumsOutGMTensor_; + expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)(ptrGroupList)); + AscendC::GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount)); + for (uint32_t localMoeIndex = 0; localMoeIndex < localExpertNum; ++localMoeIndex) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + sendCountsGlobal[localMoeIndex * epRankSize + epRankSize - 1]); + __asm__ __volatile__(""); + uint32_t tokenNum = sendCountsGlobal.GetValue(localMoeIndex * epRankSize + epRankSize - 1); + expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenNum); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + expertTokenNumsOutGMTensor_[localMoeIndex]); + __asm__ __volatile__(""); + } + } +} + +template <> +CATLASS_DEVICE void operator()(Params const ¶ms) +{ + AivInitParams(params); + AivInitState(); + if (isSendCore) { + SendCoreFunc((GM_ADDR)params.gmX, (GM_ADDR)params.gmexpertIds, (GM_ADDR)params.ptrA, + (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx); + } + if (isRecvCore) { + RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount); + } + + auto gmSwigluOutput = reinterpret_cast<__gm__ float *>( + params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * aiCoreGroupNum * WORKSPACE_STAGES * L1TileShape::N)); + if (isCompCore) { + CompCoreFunc(params.ptrWorkspace, params.ptrScale, params.ptrPerTokenScale, gmSwigluOutput, + params.problemShape.n(), params.problemShape.k(), params.layoutScale, params.layoutPerTokenScale, + params.layoutOutput); + } + + icache_preload(8); + Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + AscendC::PipeBarrier(); + + UpdateAndCleanInfo(params.ptrGroupList, params.gmEpSendCount); + { + // 量化计算 + AscendC::GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.gmEpSendCount)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(sendCountsGlobal); + __asm__ __volatile__(""); + totalTokenCount = sendCountsGlobal.GetValue(localExpertNum * epRankSize - 1); + AscendC::PipeBarrier(); + typename BlockQuant::Params quantParams{ + gmSwigluOutput, params.layoutOutput, // input: swiglu output + params.ptrDequantScale, params.layoutDequantScale, // output: quant token scale + params.ptrOutput, params.layoutOutput // output: x2 + }; + uint32_t nOut = params.problemShape.n() / 2; + + BlockQuant blockQuant(resource, quantParams); + MatrixCoord quantShape(totalTokenCount, nOut); + MatrixCoord quantBlockShape(16U, 2048U); + Epilogue::Tile::EpilogueHorizontalTileSwizzle quantSwizzle(quantShape, quantBlockShape); + for (uint32_t loopIdx = aiCoreGroupIdx; loopIdx < quantSwizzle.GetLoops(); loopIdx += aiCoreGroupNum) { + auto blockCoord = quantSwizzle.GetTileCoord(loopIdx); + auto actualBlockShape = quantSwizzle.GetActualTileShape(blockCoord); + blockQuant(quantBlockShape, blockCoord, actualBlockShape); + } + } +} + +private: +friend struct AicWaitFunc1; +friend struct AicSetFunc1; + +struct AicWaitFunc1 { + CATLASS_DEVICE + AicWaitFunc1() = default; + + CATLASS_DEVICE + void operator()() const + { + CheckSyncFlag(flagAddr, idx, target); + } + + __gm__ uint8_t *flagAddr; + uint8_t idx; + uint32_t target; +}; + +struct AicSetFunc1 { + CATLASS_DEVICE + AicSetFunc1() = default; + + CATLASS_DEVICE + void operator()() const + { + 
EncreaseSyncFlag(flagAddr, idx); + } + + __gm__ uint8_t *flagAddr; + uint8_t idx; +}; + +AicWaitFunc1 aicWaitFunc1; +AicSetFunc1 aicSetFunc1; +Arch::Resource resource; + +AscendC::LocalTensor expertIdsTensor_; + +// 卡与专家相关 +uint32_t epRankSize{0}; +uint32_t epRankId{0}; +bool hasShareExpert{false}; +bool isShareExpert{false}; +uint32_t expertCntUp{0}; +uint32_t localExpertNum{0}; +uint32_t sharedExpertRankNum{0}; +uint32_t moeExpertNumPerRank{0}; +uint32_t moeExpertNum{0}; + +// token相关 +uint32_t hOutSize{0}; +uint32_t scaleParamPad{0}; +uint32_t hCommuSize{0}; +uint32_t axisHCommu{0}; +uint32_t axisBS{0}; +uint32_t totalTokenCount{0}; +uint32_t expertIdsCnt{0}; + +// 状态相关 +int32_t tokenFlag{0}; // token到达的flag +int32_t vToCFlag{0}; // V通知C的flag +int32_t dataState{0}; // 当前核的状态,与combine配合 +int32_t cvDataState{0}; // 当前核的状态,CV配合 +int32_t state{0}; // count的flag选择依据 +float sumTarget{0.0}; // count达到的数量 + +// 共享内存相关 +__gm__ HcclOpResParam *winContext_; +GM_ADDR statusDataSpaceGm; +uint32_t stateOffset{0}; +uint64_t expertPerSizeOnWin{0}; +uint64_t winDataSizeOffset{0}; + +// 核上资源相关 +int64_t ubOffset; + +// 分核相关 +bool isSendCore{false}; +bool isRecvCore{false}; +bool isCompCore{false}; // 参与计算deq_swiglu +uint32_t aiCoreGroupNum{0}; +uint32_t aiCoreGroupIdx{0}; +uint32_t subBlockNum{0}; +uint32_t aicNum{0}; +uint32_t sendCoreNum{0}; +uint32_t recvCoreNum{0}; +uint32_t compCoreNum{0}; +uint32_t aivIdx{0}; +uint32_t aicIdx{0}; +uint32_t sendCoreIdx{0}; +uint32_t recvCoreIdx{0}; +uint32_t compCoreIdx{0}; +uint32_t aivStateGlobalCoreIdx{0}; +uint32_t aicStateGlobalCoreIdx{0}; +uint32_t sendToMoeAivNum{0}; +uint32_t sendToShareAivNum{0}; +}; + +} // namespace Catlass::Gemm::Kernel + +namespace Catlass::Gemm::Kernel { + +template +class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using ElementDequantScale = typename BlockQuant::ElementDequantScale; + using LayoutDequantScale = typename BlockQuant::LayoutDequantScale; + using ElementOutput = typename BlockQuant::ElementOutput; + using LayoutOutput = typename BlockQuant::LayoutOutput; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ 
ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementOutput *ptrOutput; + LayoutOutput layoutOutput; + __gm__ ElementDequantScale *ptrDequantScale; + LayoutDequantScale layoutDequantScale; + GM_ADDR ptrWorkspace; + + // Methods + CATLASS_DEVICE + Params() {} + + CATLASS_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, + LayoutA const &layoutA_, GM_ADDR ptrB_, LayoutB const &layoutB_, GM_ADDR ptrScale_, + LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_, + GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrOutput(reinterpret_cast<__gm__ ElementOutput *>(ptrOutput_)), + layoutOutput(layoutOutput_), + ptrDequantScale(reinterpret_cast<__gm__ ElementDequantScale *>(ptrDequantScale_)), + layoutDequantScale(layoutDequantScale_), + ptrWorkspace(ptrWorkspace_) + {} + }; + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch() + { + Arch::FlagID flagId = 0; + for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { + flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++); + flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); + aicWaitFuncList[stageId] = {this, stageId}; + aicSetFuncList[stageId] = {this, stageId}; + } + } + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? 
groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? 
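Two pieces of arithmetic drive the grouped matmul loop above: ptrGroupList stores cumulative token counts, so group g owns groupList[g] - groupList[g-1] rows; and blocks are dealt out round-robin across cores, with the deal for group g+1 starting at the core right after the last block of group g. A host-side sketch of both, with illustrative helper names.

#include <cstdint>
#include <vector>

// Rows (M) owned by a group, given a cumulative group list.
static uint32_t group_rows(const std::vector<int64_t> &groupList, uint32_t groupIdx)
{
    return static_cast<uint32_t>(groupIdx == 0 ? groupList[0]
                                               : groupList[groupIdx] - groupList[groupIdx - 1]);
}

// First loop index this core takes when the group's deal starts at startCoreIdx.
static uint32_t start_loop_idx(uint32_t coreIdx, uint32_t startCoreIdx, uint32_t coreNum)
{
    return ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx;
}

// Where the deal for the next group starts, after a group with coreLoops blocks.
static uint32_t next_start_core(uint32_t startCoreIdx, uint32_t coreLoops, uint32_t coreNum)
{
    return (startCoreIdx + coreLoops) % coreNum;
}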
(stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); + --stageUsed; + } + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + auto ptrD = reinterpret_cast<__gm__ float *>( + params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * coreNum * WORKSPACE_STAGES * L1TileShape::N)); + + uint32_t mActual = groupList.GetValue(params.problemCount - 1); + uint32_t nOut = params.problemShape.n() / 2; + + { + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource); + + uint32_t stageId = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = params.layoutOutput.GetTileLayout(MakeCoord(currentM, nOut)); + + EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, + layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + ptrD + gmGroupOffsetD, + layoutD}; + + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); + blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]); + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? 
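The drain loop above walks the still-in-flight workspace stages oldest first: with stageId pointing at the next stage to issue and stageUsed stages issued but not yet waited on, the wrap-around expression is just (stageId - stageUsed) modulo WORKSPACE_STAGES. A host-side check of that equivalence; constants are illustrative.

#include <cassert>
#include <cstdint>

static uint32_t oldest_inflight_stage(uint32_t stageId, uint32_t stageUsed, uint32_t stages)
{
    return (stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + stages - stageUsed);
}

int main()
{
    const uint32_t stages = 4;  // stands in for WORKSPACE_STAGES
    for (uint32_t stageId = 0; stageId < stages; ++stageId) {
        for (uint32_t used = 1; used <= stages; ++used) {
            // Same value as the kernel's branchy form, written modularly.
            assert(oldest_inflight_stage(stageId, used, stages) ==
                   (stageId + stages - used) % stages);
        }
    }
    return 0;
}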
(stageId + 1) : 0; + } + + gmGroupOffsetScale += inGroupProblemShape.n(); + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += currentM * nOut; + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + } + + Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + + { + typename BlockQuant::Params quantParams{ptrD, + params.layoutOutput, + params.ptrDequantScale, + params.layoutDequantScale, + params.ptrOutput, + params.layoutOutput}; + + BlockQuant blockQuant(resource, quantParams); + MatrixCoord quantShape(mActual, nOut); + MatrixCoord quantBlockShape(16U, 2048U); + Epilogue::Tile::EpilogueHorizontalTileSwizzle quantSwizzle(quantShape, quantBlockShape); + for (uint32_t loopIdx = coreIdx; loopIdx < quantSwizzle.GetLoops(); loopIdx += coreNum) { + auto blockCoord = quantSwizzle.GetTileCoord(loopIdx); + auto actualBlockShape = quantSwizzle.GetActualTileShape(blockCoord); + + blockQuant(quantBlockShape, blockCoord, actualBlockShape); + } + } + } + +private: + friend struct AicWaitFunc; + friend struct AicSetFunc; + + struct AicWaitFunc { + using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< + BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + + CATLASS_DEVICE + AicWaitFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + struct AicSetFunc { + using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< + BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + + CATLASS_DEVICE + AicSetFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; + Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; + + AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES]; + AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel diff --git a/csrc/deepep/pybind_extension.cpp b/csrc/deepep/pybind_extension.cpp index 9b587840..6e4a4055 100644 --- a/csrc/deepep/pybind_extension.cpp +++ b/csrc/deepep/pybind_extension.cpp @@ -30,11 +30,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) pybind11::class_(m, "Buffer") .def(pybind11::init()) .def("is_available", &deep_ep::Buffer::is_available) + .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) .def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout) .def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer) .def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch) .def("intranode_combine", &deep_ep::Buffer::intranode_combine) + .def("internode_dispatch", &deep_ep::Buffer::internode_dispatch) + .def("internode_combine", &deep_ep::Buffer::internode_combine) .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch) .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine) .def("fused_deep_moe", &deep_ep::Buffer::fused_deep_moe); diff --git a/python/deep_ep/deep_ep/buffer.py b/python/deep_ep/deep_ep/buffer.py index 36477b8d..eff1c8a8 100644 --- a/python/deep_ep/deep_ep/buffer.py +++ b/python/deep_ep/deep_ep/buffer.py @@ -289,6 +289,24 @@ def dispatch( # Default config 
config = self.get_dispatch_config(self.group_size) if config is None else config + # Internode + if self.runtime.get_num_rdma_ranks() > 1: + return self.internode_dispatch( + x, + handle, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + is_token_in_rank, + num_tokens_per_expert, + topk_idx, + topk_weights, + expert_alignment, + config, + previous_event, + async_finish, + allocate_on_comm_stream, + ) + # Launch the kernel with cached or non-cached mode if isinstance(x, tuple): raise NotImplementedError("Not support fp8") @@ -356,8 +374,6 @@ def dispatch( EventOverlap(event), ) - # noinspection PyTypeChecker - @log_parameters() def combine( self, @@ -394,6 +410,19 @@ def combine( recv_topk_weights: the reduced top-k weights from its dispatch ranks. event: the event after executing the kernel (valid only if `async_finish` is set). """ + # Internode + if self.runtime.get_num_rdma_ranks() > 1: + return self.internode_combine( + x, + handle, + topk_weights, + bias, + config, + previous_event, + async_finish, + allocate_on_comm_stream, + ) + # NOTES: the second `_` is for the sending side, so we should use the third one ( rank_prefix_matrix, @@ -412,6 +441,139 @@ def combine( ) return recv_x, recv_topk_weights, EventOverlap(event) + def internode_dispatch( + self, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + handle: Optional[Tuple] = None, + num_tokens_per_rank: Optional[torch.Tensor] = None, + num_tokens_per_rdma_rank: Optional[torch.Tensor] = None, + is_token_in_rank: Optional[torch.Tensor] = None, + num_tokens_per_expert: Optional[torch.Tensor] = None, + topk_idx: Optional[torch.Tensor] = None, + topk_weights: Optional[torch.Tensor] = None, + expert_alignment: int = 1, + config: Optional[Config] = None, + previous_event: Optional[EventOverlap] = None, + async_finish: bool = False, + allocate_on_comm_stream: bool = False, + ) -> Tuple[ + Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + List[int], + Tuple, + EventOverlap, + ]: + """ + Internode dispatch implementation, for more details, please refer to the `dispatch` docs. + Normally, you should not directly call this function. + """ + assert config is not None + + # Launch the kernel with cached or non-cached mode + x, x_scales = x if isinstance(x, tuple) else (x, None) + use_quant = False + if handle is not None: + raise NotImplementedError( + "Optional communication handle is not supported yet." 
+ ) + else: + assert ( + num_tokens_per_rank is not None + and is_token_in_rank is not None + and num_tokens_per_expert is not None + ) + ( + recv_x, + recv_x_scales, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + recv_src_idx, + send_head, + offset_inner, + offset_outer, + count_outer, + expand_scales, + event, + ) = self.runtime.internode_dispatch( + x, + x_scales, + topk_idx, + topk_weights, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + is_token_in_rank, + num_tokens_per_expert, + config, + getattr(previous_event, "event", None), + async_finish, + allocate_on_comm_stream, + use_quant, + ) + handle = ( + recv_src_idx, + is_token_in_rank, + send_head, # ep_rank_token_cnt + topk_idx, + topk_weights, + offset_inner, + offset_outer, # token_server_idx + count_outer, + expand_scales, + ) + return ( + (recv_x, recv_x_scales) if x_scales is not None else recv_x, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + handle, + EventOverlap(event), + ) + + def internode_combine( + self, + x: torch.Tensor, + handle: Union[tuple, list], + topk_weights: Optional[torch.Tensor] = None, + bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, + config: Optional[Config] = None, + previous_event: Optional[EventOverlap] = None, + async_finish: bool = False, + allocate_on_comm_stream: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]: + """ + Internode combine implementation, for more details, please refer to the `combine` docs. + Normally, you should not directly call this function. + """ + # assert config is not None + + ( + src_idx, + is_recv_token_in_rank, + send_head, + topk_idx, + topk_weights_ori, + offset_inner, + offset_outer, + count_outer, + expand_scales, + ) = handle + + # Launch the kernel + recv_x, recv_topk_weights, event = self.runtime.internode_combine( + x, + topk_idx, + topk_weights_ori, + src_idx, + send_head, + offset_inner, + offset_outer, + count_outer, + expand_scales, + ) + return recv_x, recv_topk_weights, EventOverlap(event) + # noinspection PyTypeChecker @log_parameters(["topk_idx"]) def low_latency_dispatch(