@@ -554,50 +554,60 @@ def forward(
554554 hidden_states , residual )
555555
556556 if (self .fusion_config .POST_MOE_FUSION
557- or self .fusion_config .POST_MLP_FUSION
558- ) and self .next_layer_layernorm is not None :
559- # Get the scale for the next allreduce fusion op
560- if self .next_attn is not None and (self .is_nvfp4
561- or self .is_fp8_quant ):
562- scale = self .next_attn .qkv_proj .input_scale
563- else :
564- # Add just the fusion op to RESIDUAL_RMS_NORM due to this is the last decoder layer
565- self .post_feed_forward_fusion_op = AllReduceFusionOp .RESIDUAL_RMS_NORM
566- scale = None
567-
568- # TODO: MIN_LATENCY_MODE is hardcoded to False
569- if cutlass_min_latency_mode :
570- shared_output = hidden_states [0 ]
571- hidden_states_activated_experts = hidden_states [1 ]
572- num_activated_experts_per_node = hidden_states [2 ]
573- experts_to_token_score = hidden_states [3 ]
574-
575- allreduce_output = self .moe_allreduce (
576- residual ,
577- self .next_layer_layernorm .weight ,
578- device_num_experts = num_activated_experts_per_node ,
579- scale_input = experts_to_token_score ,
580- active_experts_token_input = hidden_states_activated_experts ,
581- token_input = shared_output ,
582- eps = self .next_layer_layernorm .variance_epsilon ,
583- )
584- else :
585- allreduce_output = self .all_reduce (
557+ or self .fusion_config .POST_MLP_FUSION ):
558+ # If there is no extra layernorm, do another pure allreduce because
559+ # the allreduce in feed-forward module has been disabled.
560+ if self .next_layer_layernorm is None :
561+ hidden_states , residual = self .all_reduce (
586562 hidden_states ,
587563 all_reduce_params = AllReduceParams (
588- fusion_op = self . post_feed_forward_fusion_op ,
564+ fusion_op = None ,
589565 residual = residual ,
590- norm_weight = self .next_layer_layernorm .weight ,
591- scale = scale ,
592- eps = self .next_layer_layernorm .variance_epsilon ,
593566 ))
594-
595- # Unpack the allreduce output
596- if self .next_attn is not None and self .is_nvfp4 :
597- act_fp4 , act_sf , residual = allreduce_output
598- hidden_states = Fp4QuantizedTensor (act_fp4 , act_sf )
599567 else :
600- hidden_states , residual = allreduce_output
568+ # The next layernorm exists but it could be the last decoder layer.
569+ # Adjust the scale and fusion pattern.
570+ if self .next_attn is not None and (self .is_nvfp4
571+ or self .is_fp8_quant ):
572+ scale = self .next_attn .qkv_proj .input_scale
573+ else :
574+ self .post_feed_forward_fusion_op = AllReduceFusionOp .RESIDUAL_RMS_NORM
575+ scale = None
576+
577+ # TODO: MIN_LATENCY_MODE is hardcoded to False
578+ if cutlass_min_latency_mode :
579+ shared_output = hidden_states [0 ]
580+ hidden_states_activated_experts = hidden_states [1 ]
581+ num_activated_experts_per_node = hidden_states [2 ]
582+ experts_to_token_score = hidden_states [3 ]
583+
584+ allreduce_output = self .moe_allreduce (
585+ residual ,
586+ self .next_layer_layernorm .weight ,
587+ device_num_experts = num_activated_experts_per_node ,
588+ scale_input = experts_to_token_score ,
589+ active_experts_token_input =
590+ hidden_states_activated_experts ,
591+ token_input = shared_output ,
592+ eps = self .next_layer_layernorm .variance_epsilon ,
593+ )
594+ else :
595+ allreduce_output = self .all_reduce (
596+ hidden_states ,
597+ all_reduce_params = AllReduceParams (
598+ fusion_op = self .post_feed_forward_fusion_op ,
599+ residual = residual ,
600+ norm_weight = self .next_layer_layernorm .weight ,
601+ scale = scale ,
602+ eps = self .next_layer_layernorm .variance_epsilon ,
603+ ))
604+
605+ # Unpack the allreduce output
606+ if self .next_attn is not None and self .is_nvfp4 :
607+ act_fp4 , act_sf , residual = allreduce_output
608+ hidden_states = Fp4QuantizedTensor (act_fp4 , act_sf )
609+ else :
610+ hidden_states , residual = allreduce_output
601611 elif self .next_layer_layernorm :
602612 hidden_states , residual = self .next_layer_layernorm (
603613 hidden_states , residual )
@@ -710,6 +720,7 @@ def forward(
710720 scale = self .mlp .gate_up_proj .input_scale
711721 else :
712722 scale = None
723+
713724 all_reduce_output = self .all_reduce (
714725 hidden_states ,
715726 all_reduce_params = AllReduceParams (
@@ -752,25 +763,40 @@ def forward(
752763
753764 spec_metadata .maybe_capture_hidden_states (self .layer_idx ,
754765 hidden_states , residual )
755- if self .POST_MLP_FUSION and self .next_attn is not None :
756- if self .is_nvfp4 or self .is_fp8_quant :
757- scale = self .next_attn .qkv_proj .input_scale
758- else :
759- scale = None
760- all_reduce_output = self .all_reduce (
761- hidden_states ,
762- all_reduce_params = AllReduceParams (
763- fusion_op = self .post_mlp_fusion_op ,
764- residual = residual ,
765- norm_weight = self .next_layer_layernorm .weight ,
766- scale = scale ,
767- eps = self .next_layer_layernorm .variance_epsilon ,
768- ))
769- if self .is_nvfp4 :
770- act_fp4 , act_sf , residual = all_reduce_output
771- hidden_states = Fp4QuantizedTensor (act_fp4 , act_sf )
766+
767+ if self .POST_MLP_FUSION :
768+ # If there is no extra layernorm, do another pure allreduce.
769+ if self .next_layer_layernorm is None :
770+ hidden_states , residual = self .all_reduce (
771+ hidden_states ,
772+ all_reduce_params = AllReduceParams (
773+ fusion_op = None ,
774+ residual = residual ,
775+ ))
772776 else :
773- hidden_states , residual = all_reduce_output
777+ # The next layernorm exists but it could be the last decoder layer.
778+ # Adjust the scale and fusion pattern.
779+ if self .next_attn is not None and (self .is_nvfp4
780+ or self .is_fp8_quant ):
781+ scale = self .next_attn .qkv_proj .input_scale
782+ else :
783+ self .post_mlp_fusion_op = AllReduceFusionOp .RESIDUAL_RMS_NORM
784+ scale = None
785+
786+ all_reduce_output = self .all_reduce (
787+ hidden_states ,
788+ all_reduce_params = AllReduceParams (
789+ fusion_op = self .post_mlp_fusion_op ,
790+ residual = residual ,
791+ norm_weight = self .next_layer_layernorm .weight ,
792+ scale = scale ,
793+ eps = self .next_layer_layernorm .variance_epsilon ,
794+ ))
795+ if self .next_attn is not None and self .is_nvfp4 :
796+ act_fp4 , act_sf , residual = all_reduce_output
797+ hidden_states = Fp4QuantizedTensor (act_fp4 , act_sf )
798+ else :
799+ hidden_states , residual = all_reduce_output
774800 elif self .next_layer_layernorm :
775801 hidden_states , residual = self .next_layer_layernorm (
776802 hidden_states , residual )
0 commit comments