 from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
-from .moe_backend import MoEBackendSelection
+from .moe_backend import MoEBackend, MoEBackendSelection
 from .moe_load_balancer import get_moe_load_balancer
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
                            DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
@@ -234,8 +234,8 @@ def __init__(
         self.enable_dummy_allreduce = os.environ.get(
             "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
 
-        # Select MoE backend based on configuration
-        self.moe_backend = None  # Will be initialized after weights are created
+        # MoE backend is initialized lazily on first access (see the moe_backend_impl property)
+        self._moe_backend_impl = None
 
     def _check_configs(self):
         assert self._weights_created
@@ -365,8 +365,18 @@ def create_weights(self):
         self._weights_created = True
         self._check_configs()
 
-        # Initialize MoE backend after weights are created
-        self.moe_backend = MoEBackendSelection.select_backend(self)
+    @property
+    def moe_backend_impl(self) -> MoEBackend:
+        """
+        Lazily initialize and return the MoE backend.
+
+        The backend is selected based on hardware capabilities and quantization
+        configuration, which are only available after weights are created.
+        """
+        if self._moe_backend_impl is None:
+            assert self._weights_created, "Weights must be created before accessing moe_backend_impl"
+            self._moe_backend_impl = MoEBackendSelection.select_backend(self)
+        return self._moe_backend_impl
 
     def dummy_allreduce(self):
         """
@@ -422,8 +432,6 @@ def forward_chunk(
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()
 
-        use_deepseek_fp8_block_scale = False
-        use_w4_group_scaling = False
         weight_dtype = self.w3_w1_weight.dtype
 
         token_selected_experts, token_final_scales = self.routing_method.apply(
@@ -578,9 +586,8 @@ def forward_chunk(
             x_sf = x_sf.view((x_row, -1))
 
         elif self.has_deepseek_fp8_block_scales:
-            use_deepseek_fp8_block_scale = True
+            pass
         elif self.has_w4afp8:
-            use_w4_group_scaling = True
             weight_dtype = torch.quint4x2
         else:
             raise ValueError(
@@ -603,12 +610,12 @@ def forward_chunk(
                 sizes=None if use_dp_padding else all_rank_num_tokens)
             x_row = x.shape[0]
 
-        ep_size = self.ep_size
-        ep_rank = self.ep_rank
+        # ep_size = self.ep_size
+        # ep_rank = self.ep_rank
         w3_w1_weight = self.w3_w1_weight
         w2_weight = self.w2_weight
-        cluster_size = self.cluster_size
-        cluster_rank = self.cluster_rank
+        # cluster_size = self.cluster_size
+        # cluster_rank = self.cluster_rank
         quant_scales = self.quant_scales
 
         if use_postquant_alltoall:
@@ -697,8 +704,9 @@ def forward_chunk(
         #     tuner_top_k=tuner_top_k,
         # )
 
-        # Use the selected backend to compute MoE with the same parameters as fused_moe
-        final_hidden_states = self.moe_backend.run_moe(
+        # Run MoE via the backend interface; the module is passed first so the backend can extract its configuration automatically
+        final_hidden_states = self.moe_backend_impl.run_moe(
+            self,  # module as first parameter
             x,
             token_selected_slots,
             token_final_scales,
@@ -710,21 +718,11 @@ def forward_chunk(
             quant_scales=quant_scales,
             input_sf=x_sf,
             swizzled_input_sf=False,
-            tp_size=self.tp_size,
-            tp_rank=self.tp_rank,
-            ep_size=ep_size,
-            ep_rank=ep_rank,
-            cluster_size=cluster_size,
-            cluster_rank=cluster_rank,
-            enable_alltoall=use_all_to_all,
-            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-            use_w4_group_scaling=use_w4_group_scaling,
+            # Only runtime-varying parameters are passed explicitly; static configuration is read from the module
             min_latency_mode=False,
-            tune_max_num_tokens=self.tune_max_num_tokens,
+            use_fused_finalize=True,
             tuner_num_tokens=tuner_num_tokens,
             tuner_top_k=tuner_top_k,
-            module=
-            self,  # Additional parameter for backend to access module properties
         )
 
         # print(
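
For reference, the lazy-initialization pattern this diff introduces for `moe_backend_impl` can be sketched in isolation as below. `StubBackendSelection` and `StubMoE` are illustrative stand-ins, not the real `MoEBackendSelection` or `MoE` classes; the sketch only shows that the backend is selected once, on first access, and only after weights exist.

class StubBackendSelection:
    @staticmethod
    def select_backend(module):
        # The real selection inspects hardware capability and quantization config;
        # here we just return a placeholder backend object.
        return object()


class StubMoE:
    def __init__(self):
        self._weights_created = False
        self._moe_backend_impl = None

    def create_weights(self):
        self._weights_created = True

    @property
    def moe_backend_impl(self):
        # Resolved on first access, after weights are created, then cached.
        if self._moe_backend_impl is None:
            assert self._weights_created, "Weights must be created first"
            self._moe_backend_impl = StubBackendSelection.select_backend(self)
        return self._moe_backend_impl


moe = StubMoE()
moe.create_weights()
assert moe.moe_backend_impl is moe.moe_backend_impl  # selected once, cached afterwards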