 import torch

 from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
-from tensorrt_llm._utils import logger
+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import AllReduceStrategy
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping

 from ...distributed import AllReduce, allgather, reducescatter
 from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
+from .moe_backend import MoEBackend, MoEBackendSelection
 from .moe_load_balancer import get_moe_load_balancer
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
+                           DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
                            FP8QDQFusedMoEMethod, MoEWeightLoadingMode,
                            NVFP4CutlassFusedMoEMethod,
                            UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod)
@@ -90,6 +93,9 @@ def __init__(
         self.apply_router_weight_on_input = apply_router_weight_on_input
         self.layer_idx = layer_idx

+        # Store original hidden size before any potential padding
+        self.unpadded_hidden_size = self.hidden_size
+
         moe_load_balancer = get_moe_load_balancer()
         self.layer_load_balancer = None
         self.repeat_idx = 0
@@ -227,6 +233,9 @@ def __init__(
         self.enable_dummy_allreduce = os.environ.get(
             "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"

+        # The MoE backend is lazily initialized on first access (see the moe_backend_impl property)
+        self._moe_backend_impl = None
+
     def _check_configs(self):
         assert self._weights_created

@@ -316,7 +325,10 @@ def _get_quant_method(self):
         if self.quant_config.layer_quant_mode.has_fp8_qdq():
             return FP8QDQFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
-            return DeepSeekFP8BlockScalesFusedMoEMethod()
+            if get_sm_version() == 100:
+                return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
+            else:
+                return DeepSeekFP8BlockScalesFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_nvfp4():
             return NVFP4CutlassFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
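For context on the `get_sm_version() == 100` gate above: SM 100 corresponds to compute capability 10.0 (Blackwell-class parts such as B200/GB200). This diff does not show how `get_sm_version` is implemented in `tensorrt_llm._utils`; the helper below is an illustrative stand-in that reproduces the value the gate compares against using plain PyTorch.

```python
import torch


def _sm_version_sketch(device: int = 0) -> int:
    """Illustrative stand-in for get_sm_version(): major * 10 + minor."""
    major, minor = torch.cuda.get_device_capability(device)
    return major * 10 + minor


if torch.cuda.is_available():
    # Prints 100 on a compute-capability-10.0 (Blackwell) GPU, which is the
    # case where _get_quant_method above picks the DeepGemm-backed method.
    print(_sm_version_sketch())
```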
@@ -339,6 +351,19 @@ def create_weights(self):
         self._weights_created = True
         self._check_configs()

+    @property
+    def moe_backend_impl(self) -> MoEBackend:
+        """
+        Lazily initialize and return the MoE backend.
+
+        The backend is selected based on hardware capabilities and quantization
+        configuration, which are only available after weights are created.
+        """
+        if self._moe_backend_impl is None:
+            assert self._weights_created, "Weights must be created before accessing moe_backend"
+            self._moe_backend_impl = MoEBackendSelection.select_backend(self)
+        return self._moe_backend_impl
+
     def dummy_allreduce(self):
         """
         Debug function for eliminating imbalance during performance analysis.
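To make the lazy-selection flow above easier to follow in isolation, here is a self-contained sketch of the same idiom. Every name in it (`_CutlassBackend`, `_DeepGemmBackend`, `_BackendSelection`, `sm_version`, and so on) is an illustrative stand-in; the actual `MoEBackend` and `MoEBackendSelection` classes live in the new `moe_backend` module, which this diff does not show.

```python
class _CutlassBackend:
    name = "cutlass"


class _DeepGemmBackend:
    name = "deepgemm"


class _BackendSelection:
    """Stand-in for MoEBackendSelection: picks a backend from module state."""

    @staticmethod
    def select_backend(module):
        # Hypothetical rule echoing the SM 100 / FP8-block-scales gating seen
        # elsewhere in this diff.
        if module.sm_version == 100 and module.has_deepseek_fp8_block_scales:
            return _DeepGemmBackend()
        return _CutlassBackend()


class _FusedMoELike:
    """Stand-in for the fused-MoE module: backend chosen on first access."""

    def __init__(self):
        self._weights_created = False
        self._moe_backend_impl = None
        self.sm_version = 100
        self.has_deepseek_fp8_block_scales = True

    def create_weights(self):
        self._weights_created = True

    @property
    def moe_backend_impl(self):
        # Deferred until weights exist, mirroring the property added above.
        if self._moe_backend_impl is None:
            assert self._weights_created
            self._moe_backend_impl = _BackendSelection.select_backend(self)
        return self._moe_backend_impl


moe = _FusedMoELike()
moe.create_weights()
print(moe.moe_backend_impl.name)  # -> "deepgemm"
```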
@@ -389,8 +414,6 @@ def forward_chunk(
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()

-        use_deepseek_fp8_block_scale = False
-        use_w4_group_scaling = False
         weight_dtype = self.w3_w1_weight.dtype

         token_selected_experts, token_final_scales = self.routing_method.apply(
@@ -544,9 +567,8 @@ def forward_chunk(
                x_sf = x_sf.view((x_row, -1))

            elif self.has_deepseek_fp8_block_scales:
-                use_deepseek_fp8_block_scale = True
+                pass
            elif self.has_w4afp8:
-                use_w4_group_scaling = True
                weight_dtype = torch.quint4x2
            else:
                raise ValueError(
@@ -569,12 +591,8 @@ def forward_chunk(
                 sizes=None if use_dp_padding else all_rank_num_tokens)
             x_row = x.shape[0]

-        ep_size = self.ep_size
-        ep_rank = self.ep_rank
         w3_w1_weight = self.w3_w1_weight
         w2_weight = self.w2_weight
-        cluster_size = self.cluster_size
-        cluster_rank = self.cluster_rank
         quant_scales = self.quant_scales

         if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -640,7 +658,8 @@ def forward_chunk(
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
             )

-        final_hidden_states = torch.ops.trtllm.fused_moe(
+        final_hidden_states = self.moe_backend_impl.run_moe(
+            self,
             x,
             token_selected_slots,
             token_final_scales,
@@ -652,17 +671,8 @@ def forward_chunk(
             quant_scales=quant_scales,
             input_sf=x_sf,
             swizzled_input_sf=False,
-            tp_size=self.tp_size,
-            tp_rank=self.tp_rank,
-            ep_size=ep_size,
-            ep_rank=ep_rank,
-            cluster_size=cluster_size,
-            cluster_rank=cluster_rank,
-            enable_alltoall=use_all_to_all,
-            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-            use_w4_group_scaling=use_w4_group_scaling,
             min_latency_mode=False,
-            tune_max_num_tokens=self.tune_max_num_tokens,
+            use_fused_finalize=True,
             tuner_num_tokens=tuner_num_tokens,
             tuner_top_k=tuner_top_k,
         )
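The net effect of the last two hunks is that the call site no longer threads the tp/ep/cluster ranks and the quantization flags through as keyword arguments; the selected backend can read them from the module it is handed. The sketch below shows only that delegation pattern; it is an assumption about the shape of the new `moe_backend` module, whose real implementation is not part of this diff, and all names in it are illustrative.

```python
from dataclasses import dataclass


@dataclass
class _ModuleLike:
    """Illustrative stand-in carrying only the fields the sketch reads."""
    tp_size: int = 1
    tp_rank: int = 0
    ep_size: int = 1
    ep_rank: int = 0
    has_deepseek_fp8_block_scales: bool = False
    has_w4afp8: bool = False


class _BackendLike:
    """Illustrative stand-in for a MoE backend with a run_moe entry point."""

    def run_moe(self, module, x, token_selected_slots, token_final_scales,
                **kwargs):
        # Flags that used to be explicit call-site arguments are derived from
        # the module itself.
        use_deepseek_fp8_block_scale = module.has_deepseek_fp8_block_scales
        use_w4_group_scaling = module.has_w4afp8
        # A real backend would launch the fused kernel here; the sketch just
        # reports what it resolved.
        return {
            "tp": (module.tp_size, module.tp_rank),
            "ep": (module.ep_size, module.ep_rank),
            "use_deepseek_fp8_block_scale": use_deepseek_fp8_block_scale,
            "use_w4_group_scaling": use_w4_group_scaling,
            "passthrough_kwargs": sorted(kwargs),
        }


module = _ModuleLike(ep_size=8, ep_rank=3, has_deepseek_fp8_block_scales=True)
print(_BackendLike().run_moe(module, x=None, token_selected_slots=None,
                             token_final_scales=None, min_latency_mode=False,
                             use_fused_finalize=True))
```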