From 28c20d9bb2c53b67c9433c6de58365a3d5045f88 Mon Sep 17 00:00:00 2001 From: Sunny-bot1 <592045536@qq.com> Date: Mon, 15 Sep 2025 11:33:19 +0800 Subject: [PATCH] fix wint8 ep --- custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu | 192 ++++++------------ 1 file changed, 63 insertions(+), 129 deletions(-) diff --git a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu index 1c3a45e50e..fe01400f00 100644 --- a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu +++ b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu @@ -448,137 +448,71 @@ void EPMoeDispatchKernel(const paddle::Tensor& input, auto place = input.place(); const int gridx = min(132 * 8, num_rows); if (moe_quant_type == "w4a8") { - if (num_experts_per_rank == 8) { - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 127.0, - -127.0 - ); - } else if (num_experts_per_rank == 16) { - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 127.0, - -127.0 - ); - } + DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, + permute_x_kernel<<>>( + input.data(), + topk_ids.data(), + topk_weights.data(), + token_nums_per_expert.data(), + up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, + moe_topk, + num_rows, + token_nums_this_rank, + hidden_size, + permute_input->data(), + permute_indices_per_token->data(), + dst_weights->data(), + dst_indices->data(), + cumsum_idx_gpu->data(), + token_nums_per_expert_cumsum->data(), + expert_idx_per_token->data(), + 127.0, + -127.0 + );) } else if (moe_quant_type == "w4afp8") { - if (num_experts_per_rank == 8) { - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 448.0f, - -448.0f - ); - } else if (num_experts_per_rank == 16) { - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 448.0f, - -448.0f - ); - } + DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, + permute_x_kernel<<>>( + input.data(), + topk_ids.data(), + topk_weights.data(), + token_nums_per_expert.data(), + up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, + moe_topk, + num_rows, + token_nums_this_rank, + hidden_size, + permute_input->data(), + permute_indices_per_token->data(), + dst_weights->data(), + dst_indices->data(), + cumsum_idx_gpu->data(), + token_nums_per_expert_cumsum->data(), + expert_idx_per_token->data(), + 448.0f, + -448.0f + );) } else { - if (num_experts_per_rank == 8) { - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 127.0, - -127.0 - ); - } else if (num_experts_per_rank == 16) { - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 127.0, - -127.0 - ); - } + DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, + permute_x_kernel<<>>( + input.data(), + topk_ids.data(), + topk_weights.data(), + token_nums_per_expert.data(), + up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, + moe_topk, + num_rows, + token_nums_this_rank, + hidden_size, + permute_input->data(), + permute_indices_per_token->data(), + dst_weights->data(), + dst_indices->data(), + cumsum_idx_gpu->data(), + token_nums_per_expert_cumsum->data(), + expert_idx_per_token->data(), + 127.0, + -127.0 + );) } }