Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 63 additions & 129 deletions custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu
Original file line number Diff line number Diff line change
Expand Up @@ -453,137 +453,71 @@ void EPMoeDispatchKernel(const paddle::Tensor& input,
auto place = input.place();
const int gridx = min(132 * 8, num_rows);
if (moe_quant_type == "w4a8") {
if (num_experts_per_rank == 8) {
permute_x_kernel<data_t, int8_t, 8><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<int8_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);
} else if (num_experts_per_rank == 16) {
permute_x_kernel<data_t, int8_t, 16><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<int8_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);
}
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
permute_x_kernel<data_t, int8_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<int8_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);)
} else if (moe_quant_type == "w4afp8") {
if (num_experts_per_rank == 8) {
permute_x_kernel<data_t, data_t_fp8, 8, 512><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t_fp8>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
448.0f,
-448.0f
);
} else if (num_experts_per_rank == 16) {
permute_x_kernel<data_t, data_t_fp8, 16, 512><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t_fp8>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
448.0f,
-448.0f
);
}
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
permute_x_kernel<data_t, data_t_fp8, NUM_EXPERTS_PER_RANK, 512><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t_fp8>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
448.0f,
-448.0f
);)
} else {
if (num_experts_per_rank == 8) {
permute_x_kernel<data_t, data_t, 8><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);
} else if (num_experts_per_rank == 16) {
permute_x_kernel<data_t, data_t, 16><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);
}
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
permute_x_kernel<data_t, data_t, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
input.data<data_t>(),
topk_ids.data<int64_t>(),
topk_weights.data<float>(),
token_nums_per_expert.data<int>(),
up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data<float>() : nullptr,
moe_topk,
num_rows,
token_nums_this_rank,
hidden_size,
permute_input->data<data_t>(),
permute_indices_per_token->data<int>(),
dst_weights->data<float>(),
dst_indices->data<int>(),
cumsum_idx_gpu->data<int>(),
token_nums_per_expert_cumsum->data<int64_t>(),
expert_idx_per_token->data<int64_t>(),
127.0,
-127.0
);)
}
}

Expand Down