
Commit 60ea9a3

hyukn and dominicshanshan authored and committed
[None][fix] Complete the last missing allreduce op in Llama3/4. (NVIDIA#6850)
In some circumstances, the allreduce op of the last decoder layer is missing for the Llama3 and Llama4 models.

Signed-off-by: Yukun He <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent e33c37b commit 60ea9a3
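
For context, a minimal self-contained sketch of the failure mode (hypothetical names, not the TensorRT-LLM API; the allreduce is stubbed as a sum over simulated tensor-parallel partial results). With post-MLP/MoE fusion enabled, the feed-forward module skips its own allreduce and defers it to a fused allreduce + RMSNorm in the decoder layer; the old guard also required a next-layer layernorm, so the last decoder layer never ran the deferred op:

def allreduce(partials):
    # Stand-in for the real NCCL allreduce: sum per-rank partial results.
    return sum(partials)

def old_post_ffn(partials, next_layer_layernorm):
    # Old guard: fusion AND a next-layer layernorm. The last decoder layer
    # has next_layer_layernorm == None, so the deferred allreduce is skipped
    # and the per-rank partials are returned un-reduced.
    if next_layer_layernorm is not None:
        return next_layer_layernorm(allreduce(partials))
    return partials  # BUG: still per-rank partials

def fixed_post_ffn(partials, next_layer_layernorm):
    # Fix: the allreduce always runs; only the RMSNorm fusion is optional.
    reduced = allreduce(partials)
    if next_layer_layernorm is not None:
        return next_layer_layernorm(reduced)
    return reduced

partials = [1.0, 2.0, 3.0, 4.0]        # outputs of 4 tensor-parallel ranks
print(old_post_ffn(partials, None))    # [1.0, 2.0, 3.0, 4.0] -- wrong
print(fixed_post_ffn(partials, None))  # 10.0 -- correct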

File tree

1 file changed: +83 -57 lines changed

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 83 additions & 57 deletions
@@ -554,50 +554,60 @@ def forward(
                                                       hidden_states, residual)
 
         if (self.fusion_config.POST_MOE_FUSION
-                or self.fusion_config.POST_MLP_FUSION
-            ) and self.next_layer_layernorm is not None:
-            # Get the scale for the next allreduce fusion op
-            if self.next_attn is not None and (self.is_nvfp4
-                                               or self.is_fp8_quant):
-                scale = self.next_attn.qkv_proj.input_scale
-            else:
-                # Add just the fusion op to RESIDUAL_RMS_NORM due to this is the last decoder layer
-                self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
-                scale = None
-
-            # TODO: MIN_LATENCY_MODE is hardcoded to False
-            if cutlass_min_latency_mode:
-                shared_output = hidden_states[0]
-                hidden_states_activated_experts = hidden_states[1]
-                num_activated_experts_per_node = hidden_states[2]
-                experts_to_token_score = hidden_states[3]
-
-                allreduce_output = self.moe_allreduce(
-                    residual,
-                    self.next_layer_layernorm.weight,
-                    device_num_experts=num_activated_experts_per_node,
-                    scale_input=experts_to_token_score,
-                    active_experts_token_input=hidden_states_activated_experts,
-                    token_input=shared_output,
-                    eps=self.next_layer_layernorm.variance_epsilon,
-                )
-            else:
-                allreduce_output = self.all_reduce(
+                or self.fusion_config.POST_MLP_FUSION):
+            # If there is no extra layernorm, do another pure allreduce because
+            # the allreduce in feed-forward module has been disabled.
+            if self.next_layer_layernorm is None:
+                hidden_states, residual = self.all_reduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
-                        fusion_op=self.post_feed_forward_fusion_op,
+                        fusion_op=None,
                         residual=residual,
-                        norm_weight=self.next_layer_layernorm.weight,
-                        scale=scale,
-                        eps=self.next_layer_layernorm.variance_epsilon,
                     ))
-
-            # Unpack the allreduce output
-            if self.next_attn is not None and self.is_nvfp4:
-                act_fp4, act_sf, residual = allreduce_output
-                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
             else:
-                hidden_states, residual = allreduce_output
+                # The next layernorm exists but it could be the last decoder layer.
+                # Adjust the scale and fusion pattern.
+                if self.next_attn is not None and (self.is_nvfp4
+                                                   or self.is_fp8_quant):
+                    scale = self.next_attn.qkv_proj.input_scale
+                else:
+                    self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+                    scale = None
+
+                # TODO: MIN_LATENCY_MODE is hardcoded to False
+                if cutlass_min_latency_mode:
+                    shared_output = hidden_states[0]
+                    hidden_states_activated_experts = hidden_states[1]
+                    num_activated_experts_per_node = hidden_states[2]
+                    experts_to_token_score = hidden_states[3]
+
+                    allreduce_output = self.moe_allreduce(
+                        residual,
+                        self.next_layer_layernorm.weight,
+                        device_num_experts=num_activated_experts_per_node,
+                        scale_input=experts_to_token_score,
+                        active_experts_token_input=
+                        hidden_states_activated_experts,
+                        token_input=shared_output,
+                        eps=self.next_layer_layernorm.variance_epsilon,
+                    )
+                else:
+                    allreduce_output = self.all_reduce(
+                        hidden_states,
+                        all_reduce_params=AllReduceParams(
+                            fusion_op=self.post_feed_forward_fusion_op,
+                            residual=residual,
+                            norm_weight=self.next_layer_layernorm.weight,
+                            scale=scale,
+                            eps=self.next_layer_layernorm.variance_epsilon,
+                        ))
+
+                # Unpack the allreduce output
+                if self.next_attn is not None and self.is_nvfp4:
+                    act_fp4, act_sf, residual = allreduce_output
+                    hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+                else:
+                    hidden_states, residual = allreduce_output
         elif self.next_layer_layernorm:
             hidden_states, residual = self.next_layer_layernorm(
                 hidden_states, residual)
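
Condensed, the new Llama4 post-feed-forward logic has three cases; the outline below paraphrases the hunk above (stand-in types, not the literal method or the real AllReduceParams signature):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class AllReduceParams:  # stand-in for TensorRT-LLM's AllReduceParams
    fusion_op: Optional[Any] = None
    residual: Any = None
    norm_weight: Any = None
    scale: Any = None
    eps: float = 1e-6

def post_feed_forward(layer, hidden_states, residual, min_latency_mode):
    if layer.next_layer_layernorm is None:
        # Case 1 (the fix): last decoder layer. Its feed-forward allreduce
        # was disabled for fusion, so run a pure allreduce (fusion_op=None).
        return layer.all_reduce(
            hidden_states,
            all_reduce_params=AllReduceParams(fusion_op=None,
                                              residual=residual))
    if min_latency_mode:
        # Case 2: MoE min-latency path, fused via layer.moe_allreduce.
        ...
    else:
        # Case 3: fused allreduce + next-layer RMSNorm; with NVFP4/FP8 the
        # output is pre-quantized for the next attention's QKV input.
        ...

The pure-allreduce branch is required because the feed-forward module's internal allreduce stays disabled whenever fusion is configured, even on the last layer.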
@@ -710,6 +720,7 @@ def forward(
                 scale = self.mlp.gate_up_proj.input_scale
             else:
                 scale = None
+
             all_reduce_output = self.all_reduce(
                 hidden_states,
                 all_reduce_params=AllReduceParams(
@@ -752,25 +763,40 @@ def forward(
 
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                       hidden_states, residual)
-        if self.POST_MLP_FUSION and self.next_attn is not None:
-            if self.is_nvfp4 or self.is_fp8_quant:
-                scale = self.next_attn.qkv_proj.input_scale
-            else:
-                scale = None
-            all_reduce_output = self.all_reduce(
-                hidden_states,
-                all_reduce_params=AllReduceParams(
-                    fusion_op=self.post_mlp_fusion_op,
-                    residual=residual,
-                    norm_weight=self.next_layer_layernorm.weight,
-                    scale=scale,
-                    eps=self.next_layer_layernorm.variance_epsilon,
-                ))
-            if self.is_nvfp4:
-                act_fp4, act_sf, residual = all_reduce_output
-                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+
+        if self.POST_MLP_FUSION:
+            # If there is no extra layernorm, do another pure allreduce.
+            if self.next_layer_layernorm is None:
+                hidden_states, residual = self.all_reduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=None,
+                        residual=residual,
+                    ))
             else:
-                hidden_states, residual = all_reduce_output
+                # The next layernorm exists but it could be the last decoder layer.
+                # Adjust the scale and fusion pattern.
+                if self.next_attn is not None and (self.is_nvfp4
+                                                   or self.is_fp8_quant):
+                    scale = self.next_attn.qkv_proj.input_scale
+                else:
+                    self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+                    scale = None
+
+                all_reduce_output = self.all_reduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=self.post_mlp_fusion_op,
+                        residual=residual,
+                        norm_weight=self.next_layer_layernorm.weight,
+                        scale=scale,
+                        eps=self.next_layer_layernorm.variance_epsilon,
+                    ))
+                if self.next_attn is not None and self.is_nvfp4:
+                    act_fp4, act_sf, residual = all_reduce_output
+                    hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+                else:
+                    hidden_states, residual = all_reduce_output
         elif self.next_layer_layernorm:
             hidden_states, residual = self.next_layer_layernorm(
                 hidden_states, residual)
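
The Llama3 (dense MLP) hunk mirrors the MoE change; the essential difference lies in the guard, paraphrased here:

# Before: one guard covered both fusion and the presence of a next layer,
# so the last decoder layer (next_attn is None) skipped its allreduce.
if self.POST_MLP_FUSION and self.next_attn is not None:
    ...  # fused allreduce + next-layer RMSNorm

# After: fusion alone gates the block; the last layer now falls into a
# pure-allreduce branch instead of being skipped.
if self.POST_MLP_FUSION:
    if self.next_layer_layernorm is None:
        ...  # pure allreduce (fusion_op=None), since the MLP's was disabled
    else:
        ...  # fused allreduce + RMSNorm, adjusting scale and fusion op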
