 
 #include <nlohmann/json.hpp>
 
+#ifdef USING_OSS_CUTLASS_MOE_GEMM
+#include "tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h"
+#else
 #include "moe_kernels.h"
+#endif
+
+#include "tensorrt_llm/kernels/cutlass_kernels/include/cutlass_kernel_selector.h"
+
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/memoryUtils.h"
 #include "tensorrt_llm/common/nvtxUtils.h"
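For reference, the new cutlass_kernel_selector.h presumably supplies the CUTLASS_MOE_GEMM_NAMESPACE and CUTLASS_MOE_GEMM_KERNELS_NAMESPACE macros consumed by the using-declarations in the next hunk. A minimal sketch of what such a selector header could look like; the concrete namespace names are assumptions for illustration, not taken from this diff:

// Hypothetical sketch of a kernel-selector header: routes the MoE GEMM
// symbols to either the OSS CUTLASS build or the internal build.
// The namespace names below are assumed, not confirmed by this diff.
#ifdef USING_OSS_CUTLASS_MOE_GEMM
#define CUTLASS_MOE_GEMM_NAMESPACE tensorrt_llm::kernels::cutlass_kernels
#define CUTLASS_MOE_GEMM_KERNELS_NAMESPACE tensorrt_llm::kernels::cutlass_kernels
#else
#define CUTLASS_MOE_GEMM_NAMESPACE tensorrt_llm::kernels
#define CUTLASS_MOE_GEMM_KERNELS_NAMESPACE tensorrt_llm::kernels
#endif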
@@ -42,6 +49,12 @@ using namespace tensorrt_llm::common;
 using namespace tensorrt_llm::runtime;
 using namespace tensorrt_llm::cutlass_extensions;
 
+using namespace CUTLASS_MOE_GEMM_KERNELS_NAMESPACE;
+using CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput;
+using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::CutlassMoeFCRunner;
+using CUTLASS_MOE_GEMM_NAMESPACE::ActivationType;
+using CUTLASS_MOE_GEMM_NAMESPACE::isGatedActivation;
+
 static BufferManager::CudaStreamPtr streamPtr;
 static std::unique_ptr<BufferManager> bufferManager;
 static int deviceCount;
@@ -485,7 +498,7 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
     bool mIsGated = false;
     int mGatedMultiplier = 1;
 
-    tensorrt_llm::ActivationType mActType = tensorrt_llm::ActivationType::Relu;
+    ActivationType mActType = ActivationType::Relu;
 
     QuantParams mQuantParams{};
     bool mUseLora = false;
@@ -650,9 +663,15 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
             "Tactic Profiling GEMM " + std::to_string(static_cast<int>(gemm_to_profile)));
 
         GemmProfilerBackend profiler;
+#ifdef USING_OSS_CUTLASS_MOE_GEMM
+        profiler.init(mMoERunner, gemm_to_profile, typeToDtypeID<DataType>(), typeToDtypeID<WeightType>(),
+            typeToDtypeID<OutputType>(), mNumExperts, mK, mHiddenSize, mInterSize, mGroupSize, mActType, mUseBias,
+            mUseLora, /*min_latency_mode=*/false, /*need_weights=*/true, parallelism_config, /*enable_alltoall=*/false);
+#else
         profiler.init(mMoERunner, gemm_to_profile, typeToDtypeID<DataType>(), typeToDtypeID<WeightType>(),
             typeToDtypeID<OutputType>(), mNumExperts, mK, mHiddenSize, mInterSize, mGroupSize, mActType, mUseBias,
             mUseLora, /*min_latency_mode=*/false, /*need_weights=*/true, parallelism_config);
+#endif
         auto workspace_size = profiler.getWorkspaceSize(mTotalTokens);
         auto workspace = bufferManager->gpu(workspace_size);
 
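The two branches above differ only in the trailing enable_alltoall flag of the OSS init() signature. A hedged sketch of a helper that would keep the #ifdef in one place; initGemmProfiler is a hypothetical name, not part of this change:

#include <utility>

// Hypothetical forwarding helper: appends the extra enable_alltoall flag
// only in the OSS build. This works because that flag is the last
// parameter of the OSS init() overload shown above.
template <typename Profiler, typename... Args>
void initGemmProfiler(Profiler& profiler, Args&&... args)
{
#ifdef USING_OSS_CUTLASS_MOE_GEMM
    profiler.init(std::forward<Args>(args)..., /*enable_alltoall=*/false);
#else
    profiler.init(std::forward<Args>(args)...);
#endif
}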
@@ -760,11 +779,19 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
     {
         auto stream = streamPtr->get();
         MoeMinLatencyParams min_latency_params;
+#ifdef USING_OSS_CUTLASS_MOE_GEMM
         mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExperts, mUseFinalScale ? mScaleProbs : nullptr,
             mExpertWeight1, mExpertBias1, mActType, mExpertWeight2, mExpertBias2, mQuantParams, mTotalTokens,
             mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
-            parallelism_config, mUseLora, mLoraParams,
+            parallelism_config, /*enable_alltoall=*/false, mUseLora, mLoraParams,
             /*use_deepseek_fp8_block_scale=*/false, /*min_latency_mode=*/false, min_latency_params, stream);
+#else
+        mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExperts, mUseFinalScale ? mScaleProbs : nullptr,
+            mExpertWeight1, mExpertBias1, mActType, mExpertWeight2, mExpertBias2, mQuantParams, mTotalTokens,
+            mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
+            parallelism_config, mUseLora, mLoraParams, /*use_deepseek_fp8_block_scale=*/false,
+            /*min_latency_mode=*/false, min_latency_params, stream);
+#endif
     }
 
     void runBenchmark(benchmark::State& state);
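Here the OSS runMoe() takes enable_alltoall in the middle of the argument list, so a trailing-args forwarder like the one sketched earlier does not apply. One alternative, sketched under the assumption that the two signatures are otherwise identical, is a build-specific macro that expands to the extra argument only in the OSS configuration; MOE_ENABLE_ALLTOALL_ARG is a hypothetical name:

// Hypothetical macro shim: expands to the extra argument only when the
// OSS signature expects it, letting both builds share one runMoe() call.
#ifdef USING_OSS_CUTLASS_MOE_GEMM
#define MOE_ENABLE_ALLTOALL_ARG false,
#else
#define MOE_ENABLE_ALLTOALL_ARG
#endif

// Usage at the call site:
//   ..., mSourceToExpandedMap, parallelism_config, MOE_ENABLE_ALLTOALL_ARG
//   mUseLora, mLoraParams, /*use_deepseek_fp8_block_scale=*/false, ...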