 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/quantization.h"
 #include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h"
+#include <cstdint>
 #ifdef ENABLE_FP4
 #include <cuda_fp4.h>
 #endif
@@ -451,7 +452,7 @@ class CutlassMoeFCRunnerInterface
     virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
         int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
         void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights,
-        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
+        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const orig_hidden_size,
         int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr,
         void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
@@ -480,11 +481,12 @@ class CutlassMoeFCRunnerInterface
         float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
         int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
-        int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
-        bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
-        bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids)
+        int64_t const hidden_size, int64_t const orig_hidden_size, int64_t const inter_size,
+        int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
+        bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
+        MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+        cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
+        int* active_expert_global_ids)
         = 0;

     virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -606,7 +608,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
     void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
         int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
         void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights,
-        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
+        void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const orig_hidden_size,
         int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr,
         void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config,
         bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
@@ -641,11 +643,11 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
         int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
-        int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
-        cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
-        cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
-        int* active_expert_global_ids);
+        int64_t const hidden_size, int64_t const orig_hidden_size, int64_t const inter_size,
+        int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
+        bool use_lora, void* fc2_lora, cudaStream_t stream, MOEParallelismConfig parallelism_config,
+        bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
+        int* num_active_experts_per, int* active_expert_global_ids);

     // Overrides to allow us to forward on to the internal functions with the pointers using the correct type
     void gemm1(void const* const input, void* const output, void* const intermediate_result,
@@ -678,11 +680,12 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         float const* const token_topk_permuted_scales, int const* const unpermuted_row_to_permuted_row,
         int const* permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node,
-        int64_t const experts_per_token, float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora,
-        bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config,
-        bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode,
-        int* num_active_experts_per, int* active_expert_global_ids) override
+        int64_t const hidden_size, int64_t const orig_hidden_size, int64_t const inter_size,
+        int const num_experts_per_node, int64_t const experts_per_token, float const** alpha_scale_ptr_array,
+        bool use_lora, void* fc2_lora, bool use_deepseek_fp8_block_scale, cudaStream_t stream,
+        MOEParallelismConfig parallelism_config, bool const enable_alltoall,
+        cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per,
+        int* active_expert_global_ids) override
     {
         auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr;
         return Self::gemm2(moe_gemm_runner_, block_scale_gemm_runner, static_cast<T const*>(input), gemm_output,
@@ -691,9 +694,9 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
             static_cast<ScaleBiasType const*>(fc2_int_scales), fc2_fp8_dequant, fc2_fp4_act_flat, quant_params,
             token_topk_unpermuted_scales, token_topk_permuted_scales, unpermuted_row_to_permuted_row,
             permuted_row_to_unpermuted_row, token_selected_experts, num_valid_tokens_ptr, num_rows, expanded_num_rows,
-            hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora,
-            stream, parallelism_config, enable_alltoall, config, min_latency_mode, num_active_experts_per,
-            active_expert_global_ids);
+            hidden_size, orig_hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array,
+            use_lora, fc2_lora, stream, parallelism_config, enable_alltoall, config, min_latency_mode,
+            num_active_experts_per, active_expert_global_ids);
     }

     virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override
@@ -830,9 +833,9 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         float const* const token_topk_unpermuted_scales, int const* const unpermuted_row_to_permuted_row,
         int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts,
         int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows,
-        int64_t const hidden_size, int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k,
-        MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params,
-        cudaStream_t stream);
+        int64_t const hidden_size, int64_t const orig_hidden_size, int64_t const inter_size,
+        int64_t const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config,
+        bool const enable_alltoall, QuantParams& quant_params, cudaStream_t stream);

     T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales,
         int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq,
@@ -953,6 +956,7 @@ struct GemmProfilerBackend
     int64_t mNumExpertsPerNode{};
     int64_t mK{};
     int64_t mExpertHiddenSize{};
+    int64_t mExpertOrigHiddenSize{};
     int64_t mExpertInterSize{};
     int64_t mGroupSize{};
     ActivationType mActivationType{};
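
The change threads a new orig_hidden_size argument through runMoe/gemm2 alongside the existing hidden_size, and adds a matching mExpertOrigHiddenSize field to GemmProfilerBackend. A minimal standalone sketch of the presumed relationship between the two dimensions, assuming hidden_size is an aligned/padded version of the caller's original hidden dimension; the roundUp helper, the example sizes, and the alignment of 128 are illustrative assumptions, not part of this diff:

#include <cassert>
#include <cstdint>

// Hypothetical helper: round a dimension up to the next multiple of `alignment`.
static int64_t roundUp(int64_t value, int64_t alignment)
{
    return (value + alignment - 1) / alignment * alignment;
}

int main()
{
    // Assumed example values: the caller tracks its unpadded hidden dimension
    // (orig_hidden_size) while the kernels operate on a padded one (hidden_size).
    int64_t const orig_hidden_size = 7168;
    int64_t const hidden_size = roundUp(orig_hidden_size, 128);
    assert(hidden_size >= orig_hidden_size);
    return 0;
}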