
Commit b12be8d

Fusing slicing into finalizeMoERoutingKernel.
Signed-off-by: Bo Li <[email protected]>

Remove redundant slicing.
Signed-off-by: Bo Li <[email protected]>

TRT does not support padding, safe to assume padded/unpadded hidden sizes are the same.
Signed-off-by: Bo Li <[email protected]>

Address review comment.
Signed-off-by: Bo Li <[email protected]>

original_hidden_size -> unpadded_hidden_size
Signed-off-by: Bo Li <[email protected]>
1 parent 2f2f5cc commit b12be8d
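
For intuition, a minimal CPU sketch of what "fusing slicing into finalizeMoeRouting" means (hypothetical names, not the actual TensorRT-LLM call sites; the top-k scale-and-reduce that finalize also performs is omitted for brevity): rather than finalizing into padded rows and slicing them in a second pass, the finalize step reads with the padded stride but writes only the first unpadded columns, eliminating one full pass over the output.

// Hypothetical before/after sketch of the fusion -- not the actual code.
#include <cstdint>
#include <vector>

// Before (two passes): finalize writes [num_rows, padded_cols], then a slice
// kernel copies the first orig_cols of each row into the final output.
// After (fused): finalize reads with the padded stride and writes densely at
// the unpadded width in a single pass.
std::vector<float> finalizeFused(std::vector<float> const& permuted_rows,
    int64_t num_rows, int64_t padded_cols, int64_t orig_cols)
{
    std::vector<float> out(num_rows * orig_cols);
    for (int64_t r = 0; r < num_rows; ++r)
        for (int64_t c = 0; c < orig_cols; ++c) // padded tail is never touched
            out[r * orig_cols + c] = permuted_rows[r * padded_cols + c];
    return out;
}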

File tree

12 files changed: +135 -105 lines changed

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h

Lines changed: 27 additions & 23 deletions
@@ -21,6 +21,7 @@
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/quantization.h"
 #include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h"
+#include <cstdint>
 #ifdef ENABLE_FP4
 #include <cuda_fp4.h>
 #endif
@@ -451,7 +452,7 @@ class CutlassMoeFCRunnerInterface
     virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
         int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
         void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights,
-        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
+        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const orig_hidden_size,
         int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr,
         void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
@@ -480,11 +481,12 @@ class CutlassMoeFCRunnerInterface
         float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
         int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
-        int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
-        bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
-        bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids)
+        int64_t const hidden_size, int64_t const orig_hidden_size, int64_t const inter_size,
+        int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
+        bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
+        MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+        cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
+        int* active_expert_global_ids)
         = 0;

     virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -606,7 +608,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
     void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
         int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
         void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights,
-        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
+        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const orig_hidden_size,
         int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr,
         void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
@@ -641,11 +643,11 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
         int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
-        int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
-        cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
-        cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids);
+        int64_t const hidden_size, int64_t const orig_hidden_size, int64_t const inter_size,
+        int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
+        bool use_lora, void* fc2_lora, cudaStream_t stream, MOEParallelismConfig parallelism_config,
+        bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
+        int* num_active_experts_per, int* active_expert_global_ids);

     // Overrides to allow us to forward on to the internal functions with the pointers using the correct type
     void gemm1(void const* const input, void* const output, void* const intermediate_result,
@@ -678,11 +680,12 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
         int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
-        int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
-        bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
-        bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids) override
+        int64_t const hidden_size, int64_t const orig_hidden_size, int64_t const inter_size,
+        int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
+        bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
+        MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+        cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
+        int* active_expert_global_ids) override
     {
         auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
         return Self::gemm2(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input), gemm_output,
@@ -691,9 +694,9 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
             static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
             token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
             permuted_row_to_unpermuted_row, token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows,
-            hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora,
-            stream, parallelism_config, enable_alltoall, config, min_latency_mode, num_active_experts_per,
-            active_expert_global_ids);
+            hidden_size, orig_hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array,
+            use_lora, fc2_lora, stream, parallelism_config, enable_alltoall, config, min_latency_mode,
+            num_active_experts_per, active_expert_global_ids);
     }

     virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override
@@ -830,9 +833,9 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
         int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k,
-        MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params,
-        cudaStream_t stream);
+        int64_t const hidden_size, int64_t const orig_hidden_size, int64_t const inter_size,
+        int64_t const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config,
+        bool const enable_alltoall, QuantParams& quant_params, cudaStream_t stream);

     T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
         int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
@@ -953,6 +956,7 @@ struct GemmProfilerBackend
     int64_t mNumExpertsPerNode{};
     int64_t mK{};
     int64_t mExpertHiddenSize{};
+    int64_t mExpertOrigHiddenSize{};
     int64_t mExpertInterSize{};
     int64_t mGroupSize{};
     ActivationType mActivationType{};

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h

Lines changed: 3 additions & 2 deletions
@@ -66,8 +66,9 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro
     OutputType* reduced_unpermuted_output, ScaleBiasType const* bias, float const* final_scales,
     int const* unpermuted_row_to_permuted_row, int const* permuted_row_to_unpermuted_row,
     int const* token_selected_experts, int64_t const* expert_first_token_offset, int64_t const num_rows,
-    int64_t const cols, int64_t const experts_per_token, int64_t const num_experts_per_node,
-    MOEParallelismConfig parallelism_config, bool const enable_alltoall, cudaStream_t stream);
+    int64_t const padded_cols, int64_t const orig_cols, int64_t const experts_per_token,
+    int64_t const num_experts_per_node, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+    cudaStream_t stream);

 } // namespace cutlass_kernels
 } // namespace tensorrt_llm::kernels
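
To make the new padded_cols/orig_cols split concrete, here is a hedged CUDA sketch of a fused finalize kernel. It is illustrative only: the real launcher also takes bias, expert_first_token_offset, parallelism config, and mixed precisions, and the -1 sentinel for off-node rows is an assumption. Each block reduces one token's top-k expert rows and writes only the first orig_cols columns:

// Hypothetical fused finalize kernel -- a sketch, not the TensorRT-LLM kernel.
#include <cstdint>
#include <cuda_runtime.h>

__global__ void finalizeMoeRoutingFusedSlice(
    float const* expanded_permuted_rows,       // [expanded_num_rows, padded_cols]
    float* reduced_unpermuted_output,          // [num_rows, orig_cols]
    float const* final_scales,                 // [num_rows, experts_per_token]
    int const* unpermuted_row_to_permuted_row, // [num_rows * experts_per_token]
    int64_t num_rows, int64_t padded_cols, int64_t orig_cols, int64_t experts_per_token)
{
    int64_t const row = blockIdx.x;
    if (row >= num_rows)
        return;

    // Only the first orig_cols columns are reduced and written; the padded
    // tail [orig_cols, padded_cols) is never copied, so no slice pass is needed.
    for (int64_t col = threadIdx.x; col < orig_cols; col += blockDim.x)
    {
        float acc = 0.f;
        for (int64_t k = 0; k < experts_per_token; ++k)
        {
            int64_t const idx = row * experts_per_token + k;
            int const permuted_row = unpermuted_row_to_permuted_row[idx];
            if (permuted_row < 0) // assumed sentinel: expert row not on this node
                continue;
            acc += final_scales[idx]
                * expanded_permuted_rows[int64_t(permuted_row) * padded_cols + col];
        }
        reduced_unpermuted_output[row * orig_cols + col] = acc;
    }
}

// Example launch: one block per token, threads strided over orig_cols.
// finalizeMoeRoutingFusedSlice<<<unsigned(num_rows), 256, 0, stream>>>(
//     expanded_permuted_rows, final_output, final_scales,
//     unpermuted_row_to_permuted_row, num_rows, padded_cols, orig_cols, k);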
