Commit 42c2ec3

[https://nvbugs/5473781][fix] Fix llama 4 FP8 for PP>1 (#7220)
Signed-off-by: Mike Iovine <[email protected]>
1 parent: b1dc84b

2 files changed (+10, -9 lines)


tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 8 additions & 7 deletions
@@ -419,11 +419,12 @@ def __init__(
                 overridden_tp_size=1 if self.enable_attention_dp else None,
                 layer_idx=layer_idx,
             )
-
+            # TODO(TRTLLM-7809): Fix fusion with PP>1
             self.fusion_config.PRE_MLP_FUSION = model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and self.enable_fusion
-            self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and self.enable_fusion
+            ) and not self.enable_attention_dp and self.enable_fusion and not model_config.mapping.has_pp(
+            )
+            self.fusion_config.POST_MLP_FUSION = self.fusion_config.PRE_MLP_FUSION
+
         else:
             self.feed_forward = Llama4MoE(
                 num_experts=config.num_local_experts,
@@ -437,9 +438,9 @@ def __init__(
                 layer_idx=layer_idx)
 
             self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and self.enable_fusion
-            self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and self.enable_fusion
+            ) and not self.enable_attention_dp and self.enable_fusion and not model_config.mapping.has_pp(
+            )
+            self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
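
The net effect of both hunks is that the pre/post MLP and MoE fusion flags are set only when tensor parallelism is active and neither attention data parallelism nor pipeline parallelism is in use; before this commit, PP>1 left the fusions enabled, which is what broke Llama 4 FP8. Below is a minimal sketch of that gating logic. FakeMapping is a hypothetical stand-in for model_config.mapping; only has_tp() and has_pp() are taken from the diff, everything else is illustrative.

from dataclasses import dataclass


@dataclass
class FakeMapping:
    """Hypothetical stand-in for model_config.mapping, for illustration only."""
    tp_size: int = 1
    pp_size: int = 1

    def has_tp(self) -> bool:
        return self.tp_size > 1

    def has_pp(self) -> bool:
        return self.pp_size > 1


def fusion_enabled(mapping: FakeMapping, enable_attention_dp: bool,
                   enable_fusion: bool) -> bool:
    # Mirrors the condition introduced by this commit: fusion requires TP and is
    # disabled under attention DP or pipeline parallelism (PP > 1).
    return (mapping.has_tp() and not enable_attention_dp and enable_fusion
            and not mapping.has_pp())


# With the fix, PP > 1 turns the fusions off even when TP is enabled.
assert fusion_enabled(FakeMapping(tp_size=4, pp_size=1), False, True)
assert not fusion_enabled(FakeMapping(tp_size=4, pp_size=2), False, True)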

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 2 additions & 2 deletions
@@ -698,8 +698,8 @@ def test_chunked_prefill(self, attn_backend):
     @parametrize_with_ids("cuda_graph", [False, True])
     @pytest.mark.parametrize(
         "tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4), (8, 1, 8), (4, 1, 1),
-                                    (4, 1, 2), (4, 1, 4)],
-        ids=["tp8", "tp8ep4", "tp8ep8", "tp4", "tp4ep2", "tp4ep4"])
+                                    (4, 1, 2), (4, 1, 4), (4, 2, 1)],
+        ids=["tp8", "tp8ep4", "tp8ep8", "tp4", "tp4ep2", "tp4ep4", "tp4pp2"])
     def test_fp8(self, cuda_graph, tp_size, pp_size, ep_size):
         if get_device_memory() < 140000 and get_device_count() < 8:
             pytest.skip("Not enough memory for this test")
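
The added (4, 2, 1) case runs under the id tp4pp2: tensor parallelism 4 combined with pipeline parallelism 2, the configuration that previously failed for Llama 4 FP8. As a rough sketch of the world-size arithmetic behind the parametrization (the helper below is illustrative, not part of the test suite), tp_size and pp_size multiply to the GPU count, while ep_size partitions experts inside the TP group rather than adding ranks:

def required_gpus(tp_size: int, pp_size: int) -> int:
    # World size is the product of tensor-parallel and pipeline-parallel ranks;
    # expert parallelism reuses the TP ranks, so it does not add GPUs on top.
    return tp_size * pp_size


# The new "tp4pp2" case needs the same 8 GPUs as the existing "tp8" cases.
assert required_gpus(4, 2) == required_gpus(8, 1) == 8

Locally, the new case could be selected via pytest's id matching, e.g. -k "test_fp8 and tp4pp2" (assuming the standard generated test ids).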
