@@ -76,6 +76,7 @@ namespace
 // Abstract class for routing config
 struct RoutingConfig
 {
+    virtual void start() {}
     virtual void setRouting(int* selected_experts, int64_t num_experts, int64_t k, int64_t num_tokens) = 0;
     virtual std::string getName() = 0;
     virtual bool isDeterministic() const = 0;
@@ -143,6 +144,11 @@ struct RandomDistributionRoutingConfig : public RoutingConfig
143144 " Cannot create random routing distribution. Number of experts does not match the number of weights" );
144145 }
145146
147+ void start ()
148+ {
149+ twister.seed (0xD5 );
150+ }
151+
146152 std::string getName () override
147153 {
148154 return name;
@@ -208,6 +214,11 @@ struct UniformRoutingConfig : public RoutingConfig
 {
     std::mt19937_64 twister{0xD5};
 
+    void start() override
+    {
+        twister.seed(0xD5);
+    }
+
     std::string getName() override
     {
         return "uniform";
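
The hunks above introduce a non-pure virtual start() hook (a no-op by default) that the random routing configs override to reseed their Mersenne Twister, so regenerating routing for every buffer stays reproducible from run to run. Below is a minimal, self-contained sketch of the same pattern; the Config/RandomConfig names and the routing payload are illustrative stand-ins, not the benchmark's actual classes.

#include <random>
#include <vector>

// Base config: start() is a no-op by default, so deterministic configs need not override it.
struct Config
{
    virtual ~Config() = default;
    virtual void start() {}
    virtual void fill(std::vector<int>& out) = 0;
};

// Random config: reseeding in start() makes every benchmark run draw the same sequence.
struct RandomConfig : Config
{
    std::mt19937_64 twister{0xD5};

    void start() override
    {
        twister.seed(0xD5);
    }

    void fill(std::vector<int>& out) override
    {
        std::uniform_int_distribution<int> dist(0, 7);
        for (auto& v : out)
            v = dist(twister);
    }
};

int main()
{
    RandomConfig cfg;
    cfg.start(); // reseed once, then generate routing for each buffer
    std::vector<int> routing(16);
    for (int buffer = 0; buffer < 4; buffer++)
        cfg.fill(routing);
}
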
@@ -522,14 +533,32 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
 
     ActivationType mActType = ActivationType::Relu;
 
-    QuantParams mQuantParams{};
+    constexpr static int64_t NUM_BUFFERS = 32;
+
+    std::array<QuantParams, NUM_BUFFERS> mQuantParams{};
     bool mUseLora = false;
     bool mUsePrequantScale = false;
     int mGroupSize = -1;
-    LoraParams mLoraParams{};
+    std::array<LoraParams, NUM_BUFFERS> mLoraParams{};
 
     std::optional<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> mSelectedConfig = std::nullopt;
 
+    int64_t mBufferIndex = 0;
+    size_t mWorkspaceSize = 0;
+    size_t mExpertWeight1Size = 0;
+    size_t mExpertWeight2Size = 0;
+    size_t mExpertBias1Size = 0;
+    size_t mExpertBias2Size = 0;
+    size_t mInputTensorSize = 0;
+    size_t mFinalOutputSize = 0;
+    size_t mSourceToExpandedMapSize = 0;
+    size_t mScaleProbsSize = 0;
+    size_t mSelectedExpertsSize = 0;
+    size_t mExpertFP4WeightSf1Size = 0;
+    size_t mExpertFP4WeightSf2Size = 0;
+    size_t mExpertIntScale1Size = 0;
+    size_t mExpertIntScale2Size = 0;
+
 
     template <class T>
     T* allocBuffer(size_t size)
@@ -558,70 +587,97 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
         mGatedMultiplier = mIsGated ? 2 : 1;
         auto const gated_inter = mInterSize * mGatedMultiplier;
 
-        size_t workspace_size
-            = mMoERunner.getWorkspaceSize(mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mActType, {},
-                mUseLora, /*use_deepseek_fp8_block_scale=*/ false, /*min_latency_mode=*/ false, mUsePrequantScale);
+        mWorkspaceSize = mMoERunner.getWorkspaceSize(mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mActType,
+            {}, mUseLora, /*use_deepseek_fp8_block_scale=*/ false, /*min_latency_mode=*/ false, mUsePrequantScale);
 
-        mWorkspace = allocBuffer<char>(workspace_size);
+        mWorkspace = allocBuffer<char>(mWorkspaceSize * NUM_BUFFERS);
         size_t const expert_matrix_size = mNumExperts * mHiddenSize * mInterSize;
 
-        mExpertWeight1 = allocBuffer<WeightStorage>(expert_matrix_size * mGatedMultiplier / WEIGHT_ELEM_PER_BYTE);
-        mExpertWeight2 = allocBuffer<WeightStorage>(expert_matrix_size / WEIGHT_ELEM_PER_BYTE);
+        mExpertWeight1Size = expert_matrix_size * mGatedMultiplier / WEIGHT_ELEM_PER_BYTE;
+        mExpertWeight2Size = expert_matrix_size / WEIGHT_ELEM_PER_BYTE;
+        mExpertWeight1 = allocBuffer<WeightStorage>(mExpertWeight1Size * NUM_BUFFERS);
+        mExpertWeight2 = allocBuffer<WeightStorage>(mExpertWeight2Size * NUM_BUFFERS);
 
         mExpertBias1 = nullptr;
         mExpertBias2 = nullptr;
         if (mUseBias)
         {
-            mExpertBias1 = allocBuffer<DataType>(mNumExperts * gated_inter);
-            mExpertBias2 = allocBuffer<DataType>(mNumExperts * mHiddenSize);
+            mExpertBias1Size = mNumExperts * gated_inter;
+            mExpertBias2Size = mNumExperts * mHiddenSize;
+            mExpertBias1 = allocBuffer<DataType>(mExpertBias1Size * NUM_BUFFERS);
+            mExpertBias2 = allocBuffer<DataType>(mExpertBias2Size * NUM_BUFFERS);
         }
 
         if constexpr (INT_QUANT)
         {
-            mExpertIntScale1 = allocBuffer<DataType>(mNumExperts * gated_inter);
-            mExpertIntScale2 = allocBuffer<DataType>(mNumExperts * mHiddenSize);
+            mExpertIntScale1Size = mNumExperts * gated_inter;
+            mExpertIntScale2Size = mNumExperts * mHiddenSize;
+            mExpertIntScale1 = allocBuffer<DataType>(mExpertIntScale1Size * NUM_BUFFERS);
+            mExpertIntScale2 = allocBuffer<DataType>(mExpertIntScale2Size * NUM_BUFFERS);
 
-            mQuantParams = QuantParams::Int(mExpertIntScale1, mExpertIntScale2);
+            for (int i = 0; i < NUM_BUFFERS; i++)
+            {
+                mQuantParams[i] = QuantParams::Int(
+                    mExpertIntScale1 + mExpertIntScale1Size * i, mExpertIntScale2 + mExpertIntScale2Size * i);
+            }
         }
         else if constexpr (FP8)
         {
             mExpertFP8Scale1 = allocBuffer<float>(mNumExperts);
             mExpertFP8Scale2 = allocBuffer<float>(1);
             mExpertFP8Scale3 = allocBuffer<float>(mNumExperts);
 
-            mQuantParams = QuantParams::FP8(mExpertFP8Scale1, mExpertFP8Scale2, mExpertFP8Scale3);
+            for (int i = 0; i < NUM_BUFFERS; i++)
+            {
+                mQuantParams[i] = QuantParams::FP8(mExpertFP8Scale1, mExpertFP8Scale2, mExpertFP8Scale3);
+            }
         }
         else if constexpr (ANY_FP4)
         {
             mExpertFP4ActScale1 = allocBuffer<float>(1);
-            mExpertFP4WeightSf1 = allocBuffer<ElementSF>(num_experts * gated_inter * mHiddenSize / FP4_VECTOR_SIZE);
+            mExpertFP4WeightSf1Size = num_experts * gated_inter * mHiddenSize / FP4_VECTOR_SIZE;
+            mExpertFP4WeightSf1 = allocBuffer<ElementSF>(mExpertFP4WeightSf1Size * NUM_BUFFERS);
             mExpertFP4GlobalScale1 = allocBuffer<float>(num_experts);
 
             mExpertFP4ActScale2 = allocBuffer<float>(1);
-            mExpertFP4WeightSf2 = allocBuffer<ElementSF>(num_experts * mInterSize * mHiddenSize / FP4_VECTOR_SIZE);
+            mExpertFP4WeightSf2Size = num_experts * mInterSize * mHiddenSize / FP4_VECTOR_SIZE;
+            mExpertFP4WeightSf2 = allocBuffer<ElementSF>(mExpertFP4WeightSf2Size * NUM_BUFFERS);
             mExpertFP4GlobalScale2 = allocBuffer<float>(num_experts);
 
             auto func = NVFP4 ? QuantParams::FP4 : QuantParams::FP8MXFP4;
-            mQuantParams = func(mExpertFP4ActScale1, mExpertFP4WeightSf1, mExpertFP4GlobalScale1, mExpertFP4ActScale2,
-                mExpertFP4WeightSf2, mExpertFP4GlobalScale2, false, false);
+            for (int i = 0; i < NUM_BUFFERS; i++)
+            {
+                mQuantParams[i] = func(mExpertFP4ActScale1, mExpertFP4WeightSf1 + mExpertFP4WeightSf1Size * i,
+                    mExpertFP4GlobalScale1, mExpertFP4ActScale2, mExpertFP4WeightSf2 + mExpertFP4WeightSf2Size * i,
+                    mExpertFP4GlobalScale2, false, false);
+            }
         }
 
-        mSelectedExperts = allocBuffer<int>(mTotalTokens * mK);
-        mScaleProbs = allocBuffer<float>(mTotalTokens * mK);
-        mInputTensor = allocBuffer<DataType>(mTotalTokens * mHiddenSize);
-        mFinalOutput = allocBuffer<OutputType>(mTotalTokens * mHiddenSize);
+        mSelectedExpertsSize = mTotalTokens * mK;
+        mSelectedExperts = allocBuffer<int>(mSelectedExpertsSize * NUM_BUFFERS);
+        mScaleProbsSize = mTotalTokens * mK;
+        mScaleProbs = allocBuffer<float>(mScaleProbsSize * NUM_BUFFERS);
+        mInputTensorSize = mTotalTokens * mHiddenSize;
+        mInputTensor = allocBuffer<DataType>(mInputTensorSize * NUM_BUFFERS);
+        mFinalOutputSize = mTotalTokens * mHiddenSize;
+        mFinalOutput = allocBuffer<OutputType>(mFinalOutputSize * NUM_BUFFERS);
 
-        mSourceToExpandedMap = allocBuffer<int>(mTotalTokens * mK);
+        mSourceToExpandedMapSize = mTotalTokens * mK;
+        mSourceToExpandedMap = allocBuffer<int>(mSourceToExpandedMapSize * NUM_BUFFERS);
 
         mRoutingConfigIndex = routing_config;
         auto tactic = routingConfigCache.at(routing_config);
-        tactic->setRouting(mSelectedExperts, mNumExperts, mK, mTotalTokens);
+        tactic->start();
+        for (int i = 0; i < NUM_BUFFERS; i++)
+        {
+            tactic->setRouting(mSelectedExperts + mSelectedExpertsSize * i, mNumExperts, mK, mTotalTokens);
+        }
 
         check_cuda_error(cudaStreamSynchronize(streamPtr->get()));
     }
 
-    cudaGraph_t mGraph{};
-    cudaGraphExec_t mGraphInstance{};
+    std::array<cudaGraph_t, NUM_BUFFERS> mGraph{};
+    std::array<cudaGraphExec_t, NUM_BUFFERS> mGraphInstance{};
 
     void createGraph(MOEParallelismConfig parallelism_config)
     {
@@ -630,11 +686,15 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
 
         NVTX3_SCOPED_RANGE(BuildGraph);
 
-        check_cuda_error(cudaGraphCreate(&mGraph, 0));
-        check_cuda_error(cudaStreamBeginCapture(streamPtr->get(), cudaStreamCaptureModeThreadLocal));
-        runMoEPermute(parallelism_config);
-        check_cuda_error(cudaStreamEndCapture(streamPtr->get(), &mGraph));
-        check_cuda_error(cudaGraphInstantiate(&mGraphInstance, mGraph, nullptr, nullptr, 0));
+        for (int i = 0; i < NUM_BUFFERS; i++)
+        {
+            mBufferIndex = i;
+            check_cuda_error(cudaGraphCreate(&mGraph[i], 0));
+            check_cuda_error(cudaStreamBeginCapture(streamPtr->get(), cudaStreamCaptureModeThreadLocal));
+            runMoEPermute(parallelism_config);
+            check_cuda_error(cudaStreamEndCapture(streamPtr->get(), &mGraph[i]));
+            check_cuda_error(cudaGraphInstantiate(&mGraphInstance[i], mGraph[i], nullptr, nullptr, 0));
+        }
     }
 
     void destroyGraph()
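
createGraph() now captures one CUDA graph per buffer, setting mBufferIndex before each capture so the recorded launch arguments point at that buffer's slice. The condensed sketch below shows the same capture/instantiate/launch-by-index pattern with the plain CUDA runtime API; the cudaMemsetAsync call stands in for runMoEPermute, and kNumBuffers is shrunk for brevity.

#include <array>
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kNumBuffers = 4; // smaller than the benchmark's 32, for brevity

static void checkCuda(cudaError_t err)
{
    if (err != cudaSuccess)
        std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
}

int main()
{
    cudaStream_t stream;
    checkCuda(cudaStreamCreate(&stream));

    std::array<float*, kNumBuffers> buffers{};
    std::array<cudaGraph_t, kNumBuffers> graphs{};
    std::array<cudaGraphExec_t, kNumBuffers> execs{};

    for (int i = 0; i < kNumBuffers; i++)
    {
        checkCuda(cudaMalloc(reinterpret_cast<void**>(&buffers[i]), 1024 * sizeof(float)));
        // Capture the work for buffer i; the captured graph bakes in buffers[i] as its argument.
        checkCuda(cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal));
        checkCuda(cudaMemsetAsync(buffers[i], 0, 1024 * sizeof(float), stream)); // stand-in for runMoEPermute
        checkCuda(cudaStreamEndCapture(stream, &graphs[i]));
        checkCuda(cudaGraphInstantiate(&execs[i], graphs[i], nullptr, nullptr, 0));
    }

    // Launch by rotating through the pre-instantiated graphs, as benchmarkLoop does.
    for (int iter = 0; iter < 8; iter++)
        checkCuda(cudaGraphLaunch(execs[iter % kNumBuffers], stream));
    checkCuda(cudaStreamSynchronize(stream));

    for (int i = 0; i < kNumBuffers; i++)
    {
        checkCuda(cudaGraphExecDestroy(execs[i]));
        checkCuda(cudaGraphDestroy(graphs[i]));
        checkCuda(cudaFree(buffers[i]));
    }
    checkCuda(cudaStreamDestroy(stream));
}
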
@@ -644,24 +704,28 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
 
         NVTX3_SCOPED_RANGE(DestroyGraph);
 
-        check_cuda_error(cudaGraphExecDestroy(mGraphInstance));
-        check_cuda_error(cudaGraphDestroy(mGraph));
+        for (int i = 0; i < NUM_BUFFERS; i++)
+        {
+            check_cuda_error(cudaGraphExecDestroy(mGraphInstance[i]));
+            check_cuda_error(cudaGraphDestroy(mGraph[i]));
+        }
     }
 
     float benchmarkLoop(MOEParallelismConfig parallelism_config)
     {
+        mBufferIndex = (mBufferIndex + 1) % NUM_BUFFERS;
         auto tactic = routingConfigCache.at(mRoutingConfigIndex);
         if (!tactic->isDeterministic())
         {
-            tactic->setRouting(mSelectedExperts, mNumExperts, mK, mTotalTokens);
+            tactic->setRouting(mSelectedExperts + mSelectedExpertsSize * mBufferIndex, mNumExperts, mK, mTotalTokens);
         }
 
         {
             NVTX3_SCOPED_RANGE(BenchmarkLoopIteration);
             check_cuda_error(cudaEventRecord(mStartEvent, streamPtr->get()));
             if (useCudaGraph)
             {
-                cudaGraphLaunch(mGraphInstance, streamPtr->get());
+                cudaGraphLaunch(mGraphInstance[mBufferIndex], streamPtr->get());
             }
             else
             {
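
benchmarkLoop() now advances mBufferIndex round-robin before each iteration, so the launch (and any regenerated routing) targets a slice the previous NUM_BUFFERS - 1 iterations did not touch. A toy, CPU-only sketch of just that rotation; the Bench struct and launchesPerBuffer counter are illustrative, not part of the fixture.

#include <array>
#include <cstdint>
#include <cstdio>

constexpr int64_t kNumBuffers = 32;

struct Bench
{
    int64_t bufferIndex = 0;
    std::array<int, kNumBuffers> launchesPerBuffer{};

    void benchmarkIteration()
    {
        // Rotate first, exactly like mBufferIndex = (mBufferIndex + 1) % NUM_BUFFERS,
        // then use only this iteration's slice of every multi-buffered allocation.
        bufferIndex = (bufferIndex + 1) % kNumBuffers;
        launchesPerBuffer[bufferIndex]++; // stand-in for launching mGraphInstance[mBufferIndex]
    }
};

int main()
{
    Bench bench;
    for (int i = 0; i < 100; i++)
        bench.benchmarkIteration();
    std::printf("buffer 0 used %d times\n", bench.launchesPerBuffer[0]);
}
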
@@ -802,17 +866,29 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
         auto stream = streamPtr->get();
         MoeMinLatencyParams min_latency_params;
 #ifdef USING_OSS_CUTLASS_MOE_GEMM
-        mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExperts, mUseFinalScale ? mScaleProbs : nullptr,
-            mExpertWeight1, mExpertBias1, mActType, mExpertWeight2, mExpertBias2, mQuantParams, mTotalTokens,
-            mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
-            parallelism_config, /*enable_alltoall=*/ false, mUseLora, mLoraParams,
-            /*use_deepseek_fp8_block_scale=*/ false, /*min_latency_mode=*/ false, min_latency_params, stream);
+        mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr,
+            mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
+            mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
+            mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
+            mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
+            mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
+            mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
+            mFinalOutput + mFinalOutputSize * mBufferIndex,
+            mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
+            /*enable_alltoall=*/ false, mUseLora, mLoraParams[mBufferIndex],
+            /*use_deepseek_fp8_block_scale=*/ false, /*min_latency_mode=*/ false, min_latency_params, stream);
 #else
-        mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExperts, mUseFinalScale ? mScaleProbs : nullptr,
-            mExpertWeight1, mExpertBias1, mActType, mExpertWeight2, mExpertBias2, mQuantParams, mTotalTokens,
-            mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
-            parallelism_config, mUseLora, mLoraParams, /*use_deepseek_fp8_block_scale=*/ false,
-            /*min_latency_mode=*/ false, min_latency_params, stream);
+        mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr,
+            mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
+            mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
+            mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
+            mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
+            mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
+            mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
+            mFinalOutput + mFinalOutputSize * mBufferIndex,
+            mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config, mUseLora,
+            mLoraParams[mBufferIndex],
+            /*use_deepseek_fp8_block_scale=*/ false, /*min_latency_mode=*/ false, min_latency_params, stream);
 #endif
     }
 