import torch

from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
-from tensorrt_llm._utils import logger
+from tensorrt_llm._utils import get_sm_version, logger
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping

from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .deep_ep_utils import buffer_pool, deep_ep_installed
from .interface import MoE
+from .moe_backend import MoEBackendSelection
from .moe_load_balancer import get_moe_load_balancer
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
+                           DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
                           FP8QDQFusedMoEMethod, MoEWeightLoadingMode,
                           NVFP4CutlassFusedMoEMethod,
                           UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod)
@@ -232,6 +234,9 @@ def __init__(
        self.enable_dummy_allreduce = os.environ.get(
            "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"

+        # Select MoE backend based on configuration
+        self.moe_backend = None  # Will be initialized after weights are created
+
    def _check_configs(self):
        assert self._weights_created

@@ -318,7 +323,10 @@ def _get_quant_method(self):
        if self.quant_config.layer_quant_mode.has_fp8_qdq():
            return FP8QDQFusedMoEMethod()
        elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
-            return DeepSeekFP8BlockScalesFusedMoEMethod()
+            if get_sm_version() == 100:
+                return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
+            else:
+                return DeepSeekFP8BlockScalesFusedMoEMethod()
        elif self.quant_config.layer_quant_mode.has_nvfp4():
            return NVFP4CutlassFusedMoEMethod()
        elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
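
A quick way to see which FP8 block-scales method the new branch picks on a given machine (a sketch, assuming get_sm_version() returns the compute capability as an integer, e.g. 90 on Hopper and 100 on Blackwell, which is what the DeepGEMM path targets):

# Illustrative check, not part of the patch; uses only names visible in this diff.
import torch
from tensorrt_llm._utils import get_sm_version

if torch.cuda.is_available():
    sm = get_sm_version()
    path = ("DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm" if sm == 100
            else "DeepSeekFP8BlockScalesFusedMoEMethod")
    print(f"SM {sm} -> {path}")
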
@@ -341,6 +349,9 @@ def create_weights(self):
        self._weights_created = True
        self._check_configs()

+        # Initialize MoE backend after weights are created
+        self.moe_backend = MoEBackendSelection.select_backend(self)
+
    def dummy_allreduce(self):
        """
        Debug function for eliminating imbalance during performance analysis.
@@ -638,6 +649,7 @@ def forward_chunk(
638649 f"Not available alltoall method type: { self .alltoall_method_type !r} "
639650 )
640651
652+ # Original fused_moe call (preserved as reference)
641653 final_hidden_states = torch .ops .trtllm .fused_moe (
642654 x ,
643655 token_selected_slots ,
@@ -665,6 +677,35 @@ def forward_chunk(
            tuner_top_k=tuner_top_k,
        )

+        # Use the selected backend to compute MoE with the same parameters as fused_moe
+        # final_hidden_states = self.moe_backend.run_moe(
+        #     x,
+        #     token_selected_slots,
+        #     token_final_scales,
+        #     w3_w1_weight.view(weight_dtype),
+        #     None,  # w3_w1_bias
+        #     w2_weight.view(weight_dtype),
+        #     None,  # w2_bias
+        #     output_dtype,
+        #     quant_scales=quant_scales,
+        #     input_sf=x_sf,
+        #     swizzled_input_sf=False,
+        #     tp_size=self.tp_size,
+        #     tp_rank=self.tp_rank,
+        #     ep_size=ep_size,
+        #     ep_rank=ep_rank,
+        #     cluster_size=cluster_size,
+        #     cluster_rank=cluster_rank,
+        #     enable_alltoall=use_all_to_all,
+        #     use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
+        #     use_w4_group_scaling=use_w4_group_scaling,
+        #     min_latency_mode=False,
+        #     tune_max_num_tokens=self.tune_max_num_tokens,
+        #     tuner_num_tokens=tuner_num_tokens,
+        #     tuner_top_k=tuner_top_k,
+        #     module=self,  # Additional parameter for backend to access module properties
+        # )
+
        if self.layer_load_balancer and is_last_call:
            self.layer_load_balancer.start_set_cpu_stage()

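The forward path keeps the existing torch.ops.trtllm.fused_moe call and leaves the self.moe_backend.run_moe(...) call commented out as a reference. Read from that commented call alone, the backend entry point would have roughly the shape below; this is an inferred sketch, and the actual MoEBackendSelection / backend API in moe_backend.py may differ.

# Rough stub inferred from the commented-out run_moe(...) call above.
# Parameter names come from that call; defaults and types are assumptions.
import torch


class MoEBackendStub:
    """Hypothetical shape of the object returned by MoEBackendSelection.select_backend()."""

    def run_moe(self, x, token_selected_slots, token_final_scales,
                w3_w1_weight, w3_w1_bias, w2_weight, w2_bias, output_dtype,
                *, quant_scales=None, input_sf=None, swizzled_input_sf=False,
                tp_size=1, tp_rank=0, ep_size=1, ep_rank=0,
                cluster_size=1, cluster_rank=0, enable_alltoall=False,
                use_deepseek_fp8_block_scale=False, use_w4_group_scaling=False,
                min_latency_mode=False, tune_max_num_tokens=None,
                tuner_num_tokens=None, tuner_top_k=None,
                module=None) -> torch.Tensor:
        # `module` is the calling MoE layer, giving the backend access to its
        # mapping, quantization config, and tuning properties.
        raise NotImplementedError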