 from collections.abc import Callable
-from functools import partial
 from typing import Dict, List, Optional, Tuple, Union

 import torch
@@ -66,7 +65,6 @@ def __init__(
         enable_fused_gemm_swiglu: bool = False,
         enable_fused_gemm_attn_scaling: bool = False,
         enable_trtllm_gen: bool = False,
-        post_load_weights_hook: Optional[Callable] = None,
     ):
         # First, initialize the base class.
         super().__init__(
@@ -88,7 +86,6 @@ def __init__(
         self.enable_fused_gemm_swiglu = enable_fused_gemm_swiglu
         self.enable_fused_gemm_attn_scaling = enable_fused_gemm_attn_scaling
         self.enable_trtllm_gen = enable_trtllm_gen
-        self.post_load_weights_hook = post_load_weights_hook
         self.position_ids = None

     def load_weights(self, weights: List[Dict]):
@@ -123,9 +120,6 @@ def load_weights(self, weights: List[Dict]):
                 self.weight.view(torch.uint8),
                 128).view(torch.float8_e4m3fn)

-        if self.post_load_weights_hook is not None:
-            self.post_load_weights_hook(self)
-
     # Override apply_linear instead of forward so that we can reuse the AllReduce/AllGather logic in the parent class.
     def apply_linear(
         self,
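
The hook call removed from `load_weights` implies that something else now drives the post-load step. Below is a minimal, hypothetical sketch of one way such a dispatcher could look, assuming a model-level loader walks modules and invokes an optional `post_load_weights()` method; the `load_all_weights` name and the `weights_by_module` mapping are illustrative only and not part of this change:

```python
# Hypothetical sketch (not code from this diff) of the dispatch pattern the refactor
# relies on: after weights are loaded, the loader looks for an optional
# post_load_weights() method on each module, so per-layer hook arguments are not needed.
import torch.nn as nn


def load_all_weights(model: nn.Module, weights_by_module: dict) -> None:
    """Load weights module by module, then run each module's post-load step."""
    for name, module in model.named_modules():
        if hasattr(module, "load_weights") and name in weights_by_module:
            module.load_weights(weights_by_module[name])
    # Second pass: let modules finalize derived quantities (e.g. fused-kernel scales)
    # once all of their submodules' weights are in place.
    for module in model.modules():
        post_hook = getattr(module, "post_load_weights", None)
        if callable(post_hook):
            post_hook()
```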
@@ -298,17 +292,6 @@ def __init__(self,
             enable_trtllm_gen=True,
         )

-        # After loading both gate_up_proj and down_proj, we need to set the scales needed by the special kernels and by
-        # the trtllm-gen gemm+swiglu kernel.
-        def post_load_weights_hook(gate_up_proj, down_proj):
-            if gate_up_proj.has_fp8_qdq:
-                # For the special gemm+swiglu kernel, we need to set the inverse of the output scale, which is the inverse
-                # of down_proj's combined input scale.
-                gate_up_proj.inv_output_scale = 1.0 / down_proj.input_scale
-                # For the trtllm-gen gemm+swiglu kernel, we need to set the global scale, which is gate_up_proj's
-                # combined input scale times inv_output_scale.
-                gate_up_proj.trtllm_gen_global_scale = gate_up_proj.combined_scale * gate_up_proj.inv_output_scale
-
         self.down_proj = Llama4MinLatencyLinear(
             self.intermediate_size,
             self.hidden_size,
@@ -320,10 +303,19 @@ def post_load_weights_hook(gate_up_proj, down_proj):
             reduce_output=reduce_output,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             enable_trtllm_gen=True,
-            post_load_weights_hook=partial(post_load_weights_hook,
-                                           self.gate_up_proj),
         )

+    # After loading both gate_up_proj and down_proj, we need to set the scales needed by the special kernels and by
+    # the trtllm-gen gemm+swiglu kernel.
+    def post_load_weights(self):
+        if self.gate_up_proj.has_fp8_qdq:
+            # For the special gemm+swiglu kernel, we need to set the inverse of the output scale, which is the inverse
+            # of down_proj's combined input scale.
+            self.gate_up_proj.inv_output_scale = 1.0 / self.down_proj.input_scale
+            # For the trtllm-gen gemm+swiglu kernel, we need to set the global scale, which is gate_up_proj's
+            # combined input scale times inv_output_scale.
+            self.gate_up_proj.trtllm_gen_global_scale = self.gate_up_proj.combined_scale * self.gate_up_proj.inv_output_scale
+
     def forward(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
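
The comments in the new `post_load_weights()` describe the scale folding that the gemm+swiglu kernels expect. A small numeric sketch of that arithmetic follows; the scale values are made up, and only the two formulas (`inv_output_scale` and `trtllm_gen_global_scale`) come from the diff:

```python
# Toy sketch of the FP8 scale bookkeeping performed in post_load_weights() above.
import torch

combined_scale = torch.tensor(0.02)          # gate_up_proj's combined FP8 dequant scale (made up)
down_proj_input_scale = torch.tensor(0.05)   # quantization scale expected at down_proj's input (made up)

# The fused gemm+swiglu kernel writes its activation in down_proj's input quantization,
# so its output scale is down_proj's combined input scale and the kernel takes the inverse.
inv_output_scale = 1.0 / down_proj_input_scale

# The trtllm-gen gemm+swiglu kernel folds "dequantize, then requantize for down_proj"
# into a single global scale.
trtllm_gen_global_scale = combined_scale * inv_output_scale

# Folding check: dequantizing a raw accumulator and requantizing it matches applying
# the single global scale.
raw_acc = torch.tensor(100.0)
assert torch.allclose(raw_acc * combined_scale / down_proj_input_scale,
                      raw_acc * trtllm_gen_global_scale)
```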
@@ -450,7 +442,6 @@ def __init__(
         weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
         VANILLA,
         apply_router_weight_on_input: bool = False,
-        post_load_weights_hook: Optional[Callable] = None,
     ):

         super().__init__(
@@ -466,8 +457,6 @@ def __init__(
             apply_router_weight_on_input=apply_router_weight_on_input,
         )

-        self.post_load_weights_hook = post_load_weights_hook
-
         # Enable min-latency mode for Llama4 Maverick TP8 EP1.
         self.enable_min_latency_fused_moe = False
         if num_experts == 128 \
@@ -481,12 +470,6 @@ def __init__(
                 and apply_router_weight_on_input:
             self.enable_min_latency_fused_moe = True

-    def load_weights(self, weights: List[Dict]):
-        super().load_weights(weights)
-
-        if self.post_load_weights_hook:
-            self.post_load_weights_hook(self)
-
     def forward(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
@@ -560,22 +543,6 @@ def __init__(
             overridden_tp_size=1 if self.enable_attention_dp else None,
             reduce_output=False)

-        def post_load_weights_hook(shared_expert, experts):
-            # Set min-latency quant scales for routed experts if we plan to use min-latency MoE kernels.
-            # This is because the routed experts' input scale is after the score multiplication, so we must use the
-            # pre-score scaling input scale, which happens to be shared expert's input scale.
-            if experts.enable_min_latency_fused_moe and hasattr(
-                    shared_expert.gate_up_proj, "input_scale"):
-                pre_score_scaling_input_scale = shared_expert.gate_up_proj.input_scale
-                experts.min_latency_quant_scales = FusedMoEQuantScalesFP8(
-                    fc1_dequant=experts.fc31_dequant.data /
-                    experts.fc31_input_dequant.data *
-                    pre_score_scaling_input_scale,
-                    fc2_quant=experts.fc2_quant,
-                    fc2_dequant=experts.fc2_dequant,
-                    fc1_input_dequant=pre_score_scaling_input_scale,
-                )
-
         self.experts = Llama4MinLatencyFusedMoE(
             routing_method=Llama4RenormalizeMoeRoutingMethod(top_k),
             num_experts=num_experts,
@@ -587,8 +554,7 @@ def post_load_weights_hook(shared_expert, experts):
             weight_loading_mode=MoEWeightLoadingMode.FUSED_GATE_UP_PROJ,
             model_config=model_config,
             apply_router_weight_on_input=True,
-            post_load_weights_hook=partial(post_load_weights_hook,
-                                           self.shared_expert))
+        )

         self.router = Llama4MinLatencyLinear(
             hidden_size,
@@ -597,6 +563,22 @@ def post_load_weights_hook(shared_expert, experts):
             dtype=model_config.pretrained_config.torch_dtype,
             quant_config=None)

+    def post_load_weights(self):
+        # Set min-latency quant scales for routed experts if we plan to use min-latency MoE kernels.
+        # This is because the routed experts' input scale is after the score multiplication, so we must use the
+        # pre-score scaling input scale, which happens to be shared expert's input scale.
+        if self.experts.enable_min_latency_fused_moe and hasattr(
+                self.shared_expert.gate_up_proj, "input_scale"):
+            pre_score_scaling_input_scale = self.shared_expert.gate_up_proj.input_scale
+            self.experts.min_latency_quant_scales = FusedMoEQuantScalesFP8(
+                fc1_dequant=self.experts.fc31_dequant.data /
+                self.experts.fc31_input_dequant.data *
+                pre_score_scaling_input_scale,
+                fc2_quant=self.experts.fc2_quant,
+                fc2_dequant=self.experts.fc2_dequant,
+                fc1_input_dequant=pre_score_scaling_input_scale,
+            )
+
     def compute_routed_output(
         self,
         hidden_states,
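
The `fc1_dequant` expression in the new MoE `post_load_weights()` swaps the calibrated input scale for the pre-score one. A toy check of that algebra is below, under the assumption (typical for FP8 GEMM dequant factors, but not stated in the diff) that `fc31_dequant` was built as weight scale times the post-score input scale; all numbers are made up:

```python
# Toy illustration of the scale swap done in the MoE post_load_weights() above.
import torch

weight_scale = torch.tensor(0.004)            # per-expert FC1 weight scale (made up)
post_score_input_scale = torch.tensor(0.08)   # what fc31_input_dequant holds (made up)
pre_score_input_scale = torch.tensor(0.02)    # shared_expert.gate_up_proj.input_scale (made up)

# Assumed construction of the original dequant factor.
fc31_dequant = weight_scale * post_score_input_scale

# Divide out the post-score input scale and multiply in the pre-score one,
# exactly the expression used for fc1_dequant in the diff.
fc1_dequant = fc31_dequant / post_score_input_scale * pre_score_input_scale
assert torch.allclose(fc1_dequant, weight_scale * pre_score_input_scale)
```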