@@ -76,6 +76,7 @@ namespace
 // Abstract class for routing config
 struct RoutingConfig
 {
+    virtual void start() {}
     virtual void setRouting(int* selected_experts, int64_t num_experts, int64_t k, int64_t num_tokens) = 0;
     virtual std::string getName() = 0;
     virtual bool isDeterministic() const = 0;
@@ -143,6 +144,11 @@ struct RandomDistributionRoutingConfig : public RoutingConfig
             "Cannot create random routing distribution. Number of experts does not match the number of weights");
     }
 
+    void start()
+    {
+        twister.seed(0xD5);
+    }
+
     std::string getName() override
     {
         return name;
@@ -208,6 +214,11 @@ struct UniformRoutingConfig : public RoutingConfig
 {
     std::mt19937_64 twister{0xD5};
 
+    void start()
+    {
+        twister.seed(0xD5);
+    }
+
    std::string getName() override
    {
        return "uniform";
@@ -522,14 +533,32 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
 
     ActivationType mActType = ActivationType::Relu;
 
-    QuantParams mQuantParams{};
+    constexpr static int64_t NUM_BUFFERS = 32;
+
+    std::array<QuantParams, NUM_BUFFERS> mQuantParams{};
     bool mUseLora = false;
     bool mUsePrequantScale = false;
     int mGroupSize = -1;
-    LoraParams mLoraParams{};
+    std::array<LoraParams, NUM_BUFFERS> mLoraParams{};
 
     std::optional<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> mSelectedConfig = std::nullopt;
 
+    int64_t mBufferIndex = 0;
+    size_t mWorkspaceSize = 0;
+    size_t mExpertWeight1Size = 0;
+    size_t mExpertWeight2Size = 0;
+    size_t mExpertBias1Size = 0;
+    size_t mExpertBias2Size = 0;
+    size_t mInputTensorSize = 0;
+    size_t mFinalOutputSize = 0;
+    size_t mSourceToExpandedMapSize = 0;
+    size_t mScaleProbsSize = 0;
+    size_t mSelectedExpertsSize = 0;
+    size_t mExpertFP4WeightSf1Size = 0;
+    size_t mExpertFP4WeightSf2Size = 0;
+    size_t mExpertIntScale1Size = 0;
+    size_t mExpertIntScale2Size = 0;
+
     template <class T>
     T* allocBuffer(size_t size)
     {
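The pattern for the rest of the diff: every per-iteration element count is captured in one of the `m...Size` members above, the allocation is multiplied by `NUM_BUFFERS`, and slice `i` is addressed as `base + size * i`. Rotating through 32 independent slices presumably keeps successive iterations from re-reading warm cache lines, and it lets each pre-captured CUDA graph own a distinct set of pointers. A self-contained sketch of the arithmetic, where `bufferSlice` and the concrete sizes are illustrative only, not part of the benchmark:

```cpp
#include <cstddef>
#include <vector>

constexpr std::size_t NUM_BUFFERS = 32;

// Slice i of an allocation that holds NUM_BUFFERS back-to-back copies.
// Mirrors expressions like mInputTensor + mInputTensorSize * mBufferIndex.
template <class T>
T* bufferSlice(T* base, std::size_t slice_elems, std::size_t index)
{
    return base + slice_elems * index;
}

int main()
{
    std::size_t const slice_elems = 1024;                   // stand-in for e.g. mInputTensorSize
    std::vector<float> storage(slice_elems * NUM_BUFFERS);  // one allocation, 32 slices
    for (std::size_t iter = 0; iter < 100; iter++)
    {
        // Round-robin: a slice is revisited only every NUM_BUFFERS iterations.
        float* cur = bufferSlice(storage.data(), slice_elems, iter % NUM_BUFFERS);
        cur[0] = static_cast<float>(iter);
    }
    return 0;
}
```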
@@ -558,70 +587,97 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
         mGatedMultiplier = mIsGated ? 2 : 1;
         auto const gated_inter = mInterSize * mGatedMultiplier;
 
-        size_t workspace_size
-            = mMoERunner.getWorkspaceSize(mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mActType, {},
-                mUseLora, /*use_deepseek_fp8_block_scale=*/false, /*min_latency_mode=*/false, mUsePrequantScale);
+        mWorkspaceSize = mMoERunner.getWorkspaceSize(mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mActType,
+            {}, mUseLora, /*use_deepseek_fp8_block_scale=*/false, /*min_latency_mode=*/false, mUsePrequantScale);
 
-        mWorkspace = allocBuffer<char>(workspace_size);
+        mWorkspace = allocBuffer<char>(mWorkspaceSize * NUM_BUFFERS);
         size_t const expert_matrix_size = mNumExperts * mHiddenSize * mInterSize;
 
-        mExpertWeight1 = allocBuffer<WeightStorage>(expert_matrix_size * mGatedMultiplier / WEIGHT_ELEM_PER_BYTE);
-        mExpertWeight2 = allocBuffer<WeightStorage>(expert_matrix_size / WEIGHT_ELEM_PER_BYTE);
+        mExpertWeight1Size = expert_matrix_size * mGatedMultiplier / WEIGHT_ELEM_PER_BYTE;
+        mExpertWeight2Size = expert_matrix_size / WEIGHT_ELEM_PER_BYTE;
+        mExpertWeight1 = allocBuffer<WeightStorage>(mExpertWeight1Size * NUM_BUFFERS);
+        mExpertWeight2 = allocBuffer<WeightStorage>(mExpertWeight2Size * NUM_BUFFERS);
 
         mExpertBias1 = nullptr;
         mExpertBias2 = nullptr;
         if (mUseBias)
         {
-            mExpertBias1 = allocBuffer<DataType>(mNumExperts * gated_inter);
-            mExpertBias2 = allocBuffer<DataType>(mNumExperts * mHiddenSize);
+            mExpertBias1Size = mNumExperts * gated_inter;
+            mExpertBias2Size = mNumExperts * mHiddenSize;
+            mExpertBias1 = allocBuffer<DataType>(mExpertBias1Size * NUM_BUFFERS);
+            mExpertBias2 = allocBuffer<DataType>(mExpertBias2Size * NUM_BUFFERS);
         }
 
         if constexpr (INT_QUANT)
         {
-            mExpertIntScale1 = allocBuffer<DataType>(mNumExperts * gated_inter);
-            mExpertIntScale2 = allocBuffer<DataType>(mNumExperts * mHiddenSize);
+            mExpertIntScale1Size = mNumExperts * gated_inter;
+            mExpertIntScale2Size = mNumExperts * mHiddenSize;
+            mExpertIntScale1 = allocBuffer<DataType>(mExpertIntScale1Size * NUM_BUFFERS);
+            mExpertIntScale2 = allocBuffer<DataType>(mExpertIntScale2Size * NUM_BUFFERS);
 
-            mQuantParams = QuantParams::Int(mExpertIntScale1, mExpertIntScale2);
+            for (int i = 0; i < NUM_BUFFERS; i++)
+            {
+                mQuantParams[i] = QuantParams::Int(
+                    mExpertIntScale1 + mExpertIntScale1Size * i, mExpertIntScale2 + mExpertIntScale2Size * i);
+            }
         }
         else if constexpr (FP8)
         {
             mExpertFP8Scale1 = allocBuffer<float>(mNumExperts);
             mExpertFP8Scale2 = allocBuffer<float>(1);
             mExpertFP8Scale3 = allocBuffer<float>(mNumExperts);
 
-            mQuantParams = QuantParams::FP8(mExpertFP8Scale1, mExpertFP8Scale2, mExpertFP8Scale3);
+            for (int i = 0; i < NUM_BUFFERS; i++)
+            {
+                mQuantParams[i] = QuantParams::FP8(mExpertFP8Scale1, mExpertFP8Scale2, mExpertFP8Scale3);
+            }
         }
         else if constexpr (ANY_FP4)
        {
             mExpertFP4ActScale1 = allocBuffer<float>(1);
-            mExpertFP4WeightSf1 = allocBuffer<ElementSF>(num_experts * gated_inter * mHiddenSize / FP4_VECTOR_SIZE);
+            mExpertFP4WeightSf1Size = num_experts * gated_inter * mHiddenSize / FP4_VECTOR_SIZE;
+            mExpertFP4WeightSf1 = allocBuffer<ElementSF>(mExpertFP4WeightSf1Size * NUM_BUFFERS);
             mExpertFP4GlobalScale1 = allocBuffer<float>(num_experts);
 
             mExpertFP4ActScale2 = allocBuffer<float>(1);
-            mExpertFP4WeightSf2 = allocBuffer<ElementSF>(num_experts * mInterSize * mHiddenSize / FP4_VECTOR_SIZE);
+            mExpertFP4WeightSf2Size = num_experts * mInterSize * mHiddenSize / FP4_VECTOR_SIZE;
+            mExpertFP4WeightSf2 = allocBuffer<ElementSF>(mExpertFP4WeightSf2Size * NUM_BUFFERS);
             mExpertFP4GlobalScale2 = allocBuffer<float>(num_experts);
 
             auto func = NVFP4 ? QuantParams::FP4 : QuantParams::FP8MXFP4;
-            mQuantParams = func(mExpertFP4ActScale1, mExpertFP4WeightSf1, mExpertFP4GlobalScale1, mExpertFP4ActScale2,
-                mExpertFP4WeightSf2, mExpertFP4GlobalScale2, false, false);
+            for (int i = 0; i < NUM_BUFFERS; i++)
+            {
+                mQuantParams[i] = func(mExpertFP4ActScale1, mExpertFP4WeightSf1 + mExpertFP4WeightSf1Size * i,
+                    mExpertFP4GlobalScale1, mExpertFP4ActScale2, mExpertFP4WeightSf2 + mExpertFP4WeightSf2Size * i,
+                    mExpertFP4GlobalScale2, false, false);
+            }
         }
 
-        mSelectedExperts = allocBuffer<int>(mTotalTokens * mK);
-        mScaleProbs = allocBuffer<float>(mTotalTokens * mK);
-        mInputTensor = allocBuffer<DataType>(mTotalTokens * mHiddenSize);
-        mFinalOutput = allocBuffer<OutputType>(mTotalTokens * mHiddenSize);
+        mSelectedExpertsSize = mTotalTokens * mK;
+        mSelectedExperts = allocBuffer<int>(mSelectedExpertsSize * NUM_BUFFERS);
+        mScaleProbsSize = mTotalTokens * mK;
+        mScaleProbs = allocBuffer<float>(mScaleProbsSize * NUM_BUFFERS);
+        mInputTensorSize = mTotalTokens * mHiddenSize;
+        mInputTensor = allocBuffer<DataType>(mInputTensorSize * NUM_BUFFERS);
+        mFinalOutputSize = mTotalTokens * mHiddenSize;
+        mFinalOutput = allocBuffer<OutputType>(mFinalOutputSize * NUM_BUFFERS);
 
-        mSourceToExpandedMap = allocBuffer<int>(mTotalTokens * mK);
+        mSourceToExpandedMapSize = mTotalTokens * mK;
+        mSourceToExpandedMap = allocBuffer<int>(mSourceToExpandedMapSize * NUM_BUFFERS);
 
         mRoutingConfigIndex = routing_config;
         auto tactic = routingConfigCache.at(routing_config);
-        tactic->setRouting(mSelectedExperts, mNumExperts, mK, mTotalTokens);
+        tactic->start();
+        for (int i = 0; i < NUM_BUFFERS; i++)
+        {
+            tactic->setRouting(mSelectedExperts + mSelectedExpertsSize * i, mNumExperts, mK, mTotalTokens);
+        }
 
         check_cuda_error(cudaStreamSynchronize(streamPtr->get()));
     }
 
-    cudaGraph_t mGraph{};
-    cudaGraphExec_t mGraphInstance{};
+    std::array<cudaGraph_t, NUM_BUFFERS> mGraph{};
+    std::array<cudaGraphExec_t, NUM_BUFFERS> mGraphInstance{};
 
     void createGraph(MOEParallelismConfig parallelism_config)
     {
@@ -630,11 +686,15 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
 
         NVTX3_SCOPED_RANGE(BuildGraph);
 
-        check_cuda_error(cudaGraphCreate(&mGraph, 0));
-        check_cuda_error(cudaStreamBeginCapture(streamPtr->get(), cudaStreamCaptureModeThreadLocal));
-        runMoEPermute(parallelism_config);
-        check_cuda_error(cudaStreamEndCapture(streamPtr->get(), &mGraph));
-        check_cuda_error(cudaGraphInstantiate(&mGraphInstance, mGraph, nullptr, nullptr, 0));
+        for (int i = 0; i < NUM_BUFFERS; i++)
+        {
+            mBufferIndex = i;
+            check_cuda_error(cudaGraphCreate(&mGraph[i], 0));
+            check_cuda_error(cudaStreamBeginCapture(streamPtr->get(), cudaStreamCaptureModeThreadLocal));
+            runMoEPermute(parallelism_config);
+            check_cuda_error(cudaStreamEndCapture(streamPtr->get(), &mGraph[i]));
+            check_cuda_error(cudaGraphInstantiate(&mGraphInstance[i], mGraph[i], nullptr, nullptr, 0));
+        }
     }
 
     void destroyGraph()
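Stream capture records the pointer arguments of every launch into the graph, so a single captured graph would always replay against one fixed set of buffers; that is why the loop above sets `mBufferIndex = i` before each capture, binding graph `i` to slice `i` for good. A compilable stand-alone sketch of the same capture-per-slice idea (the `touch` kernel and sizes are invented; the instantiate call mirrors the diff's legacy five-argument form):

```cpp
#include <array>
#include <cstddef>
#include <cstdio>
#include <cuda_runtime.h>

constexpr int NUM_BUFFERS = 32;
constexpr std::size_t SLICE = 256; // elements per slice, invented for the sketch

__global__ void touch(float* p)
{
    p[threadIdx.x] += 1.0f;
}

#define CHECK(call)                                                                                                    \
    do                                                                                                                 \
    {                                                                                                                  \
        cudaError_t err = (call);                                                                                      \
        if (err != cudaSuccess)                                                                                        \
            std::printf("CUDA error: %s\n", cudaGetErrorString(err));                                                  \
    } while (0)

int main()
{
    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));
    float* base;
    CHECK(cudaMalloc(&base, SLICE * NUM_BUFFERS * sizeof(float)));

    std::array<cudaGraph_t, NUM_BUFFERS> graphs{};
    std::array<cudaGraphExec_t, NUM_BUFFERS> execs{};
    for (int i = 0; i < NUM_BUFFERS; i++)
    {
        CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal));
        touch<<<1, SLICE, 0, stream>>>(base + SLICE * i); // pointer is frozen into graph i
        CHECK(cudaStreamEndCapture(stream, &graphs[i]));
        CHECK(cudaGraphInstantiate(&execs[i], graphs[i], nullptr, nullptr, 0));
    }

    for (int iter = 0; iter < 8; iter++) // replay, rotating through the slices
        CHECK(cudaGraphLaunch(execs[iter % NUM_BUFFERS], stream));
    CHECK(cudaStreamSynchronize(stream));

    for (int i = 0; i < NUM_BUFFERS; i++)
    {
        CHECK(cudaGraphExecDestroy(execs[i]));
        CHECK(cudaGraphDestroy(graphs[i]));
    }
    CHECK(cudaFree(base));
    CHECK(cudaStreamDestroy(stream));
    return 0;
}
```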
@@ -644,24 +704,28 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
 
         NVTX3_SCOPED_RANGE(DestroyGraph);
 
-        check_cuda_error(cudaGraphExecDestroy(mGraphInstance));
-        check_cuda_error(cudaGraphDestroy(mGraph));
+        for (int i = 0; i < NUM_BUFFERS; i++)
+        {
+            check_cuda_error(cudaGraphExecDestroy(mGraphInstance[i]));
+            check_cuda_error(cudaGraphDestroy(mGraph[i]));
+        }
     }
 
     float benchmarkLoop(MOEParallelismConfig parallelism_config)
     {
+        mBufferIndex = (mBufferIndex + 1) % NUM_BUFFERS;
         auto tactic = routingConfigCache.at(mRoutingConfigIndex);
         if (!tactic->isDeterministic())
         {
-            tactic->setRouting(mSelectedExperts, mNumExperts, mK, mTotalTokens);
+            tactic->setRouting(mSelectedExperts + mSelectedExpertsSize * mBufferIndex, mNumExperts, mK, mTotalTokens);
         }
 
         {
             NVTX3_SCOPED_RANGE(BenchmarkLoopIteration);
             check_cuda_error(cudaEventRecord(mStartEvent, streamPtr->get()));
             if (useCudaGraph)
             {
-                cudaGraphLaunch(mGraphInstance, streamPtr->get());
+                cudaGraphLaunch(mGraphInstance[mBufferIndex], streamPtr->get());
             }
             else
             {
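The rotation itself is a single modular increment at the top of `benchmarkLoop()`: each timed call advances to the next slice, and any slice is reused only after `NUM_BUFFERS - 1` intervening iterations. A toy trace of the wrap-around, assuming the member's initial value of 0 (so the first timed iteration lands on slice 1):

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr int64_t NUM_BUFFERS = 32;
    int64_t buffer_index = 0; // mirrors mBufferIndex's initializer
    for (int iter = 0; iter < 40; iter++)
    {
        buffer_index = (buffer_index + 1) % NUM_BUFFERS; // same update as benchmarkLoop()
        std::printf("iteration %d -> slice %lld\n", iter, static_cast<long long>(buffer_index));
    }
    return 0;
}
```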
@@ -802,17 +866,29 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
         auto stream = streamPtr->get();
         MoeMinLatencyParams min_latency_params;
 #ifdef USING_OSS_CUTLASS_MOE_GEMM
-        mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExperts, mUseFinalScale ? mScaleProbs : nullptr,
-            mExpertWeight1, mExpertBias1, mActType, mExpertWeight2, mExpertBias2, mQuantParams, mTotalTokens,
-            mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
-            parallelism_config, /*enable_alltoall=*/false, mUseLora, mLoraParams,
-            /*use_deepseek_fp8_block_scale=*/false, /*min_latency_mode=*/false, min_latency_params, stream);
+        mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr,
+            mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
+            mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
+            mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
+            mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
+            mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
+            mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
+            mFinalOutput + mFinalOutputSize * mBufferIndex,
+            mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
+            /*enable_alltoall=*/false, mUseLora, mLoraParams[mBufferIndex],
+            /*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, min_latency_params, stream);
 #else
-        mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExperts, mUseFinalScale ? mScaleProbs : nullptr,
-            mExpertWeight1, mExpertBias1, mActType, mExpertWeight2, mExpertBias2, mQuantParams, mTotalTokens,
-            mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
-            parallelism_config, mUseLora, mLoraParams, /*use_deepseek_fp8_block_scale=*/false,
-            /*min_latency_mode=*/false, min_latency_params, stream);
+        mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr,
+            mSelectedExperts + mSelectedExpertsSize * mBufferIndex,
+            mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
+            mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
+            mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
+            mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
+            mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex,
+            mFinalOutput + mFinalOutputSize * mBufferIndex,
+            mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config, mUseLora,
+            mLoraParams[mBufferIndex],
+            /*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, min_latency_params, stream);
 #endif
     }