Commit 8ad0381

kaiyux authored and zongfeijing committed
[None] [test] Add MNNVL AlltoAll tests to pre-merge (NVIDIA#7466)
Signed-off-by: Kaiyu Xie <[email protected]>
Signed-off-by: Zongfei Jing <[email protected]>
Co-authored-by: Zongfei Jing <[email protected]>
Signed-off-by: Faradawn Yang <[email protected]>
1 parent dddffd3 commit 8ad0381

File tree

6 files changed: +41 −20 lines changed

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,7 @@
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod,
                                  MoEWeightLoadingMode, create_moe)
+from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -849,6 +850,9 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
             output_dtype=hidden_states.dtype,
             all_rank_num_tokens=all_rank_num_tokens,
             use_dp_padding=use_dp_padding,
+            **({
+                "alltoall_result_do_sum": False
+            } if isinstance(self.experts, WideEPMoE) else {}),
         )

         return routed_output
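The conditional unpacking on the added lines keeps `alltoall_result_do_sum` out of the call entirely unless the experts module is a `WideEPMoE`, so other MoE backends never receive a keyword they do not declare. A minimal, self-contained sketch of the idiom — the two backend classes here are hypothetical stand-ins, not TensorRT-LLM types:

    # Hypothetical backends illustrating the conditional-kwargs idiom above.
    class BasicMoE:
        def forward(self, x, use_dp_padding=False):
            return x

    class WideMoE:
        def forward(self, x, use_dp_padding=False, alltoall_result_do_sum=True):
            return x

    def compute_routed_output(experts, x):
        # The dict unpacks to the extra kwarg only for WideMoE, so BasicMoE
        # never sees an unexpected keyword argument.
        return experts.forward(
            x,
            use_dp_padding=False,
            **({"alltoall_result_do_sum": False}
               if isinstance(experts, WideMoE) else {}),
        )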

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 24 additions & 16 deletions
@@ -452,14 +452,15 @@ def is_post_quant_all2all_supported(self):
         return False

     def forward_chunk(
-            self,
-            x: Union[torch.Tensor, Fp4QuantizedTensor],
-            router_logits: torch.Tensor,
-            use_all_to_all: bool,
-            output_dtype: Optional[torch.dtype] = None,
-            all_rank_num_tokens: Optional[List[int]] = None,
-            use_dp_padding: Optional[bool] = None,
-            repeating_info: Tuple = (True, True),
+            self,
+            x: Union[torch.Tensor, Fp4QuantizedTensor],
+            router_logits: torch.Tensor,
+            use_all_to_all: bool,
+            output_dtype: Optional[torch.dtype] = None,
+            all_rank_num_tokens: Optional[List[int]] = None,
+            use_dp_padding: Optional[bool] = None,
+            repeating_info: Tuple = (True, True),
+            alltoall_result_do_sum: bool = True,
     ) -> torch.Tensor:
         all_rank_max_num_tokens = max(all_rank_num_tokens)
         if isinstance(x, Fp4QuantizedTensor):
@@ -474,7 +475,7 @@ def forward_chunk(
             self.layer_load_balancer.start_wait_gpu_stage()

         if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
-            pass
+            alltoall_result_do_sum = True

         weight_dtype = self.w3_w1_weight.dtype

@@ -741,7 +742,8 @@ def forward_chunk(
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             final_hidden_states = self.alltoall_combine(
-                final_hidden_states, alltoall_info, token_count)
+                final_hidden_states, alltoall_info, token_count,
+                alltoall_result_do_sum)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
             final_hidden_states = self.unpad_tensors(
                 padded, final_hidden_states)
@@ -786,6 +788,7 @@ def forward_impl(
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
+        alltoall_result_do_sum: bool = True,
         **kwargs,
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
@@ -813,7 +816,8 @@ def forward_impl(
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding,
-                repeating_info=(is_first_call, is_last_call))
+                repeating_info=(is_first_call, is_last_call),
+                alltoall_result_do_sum=alltoall_result_do_sum)
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 use_all_to_all,
@@ -871,7 +875,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         all_rank_num_tokens=all_rank_num_tokens_list[
                             idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call))
+                        repeating_info=(is_first_call, is_last_call),
+                        alltoall_result_do_sum=alltoall_result_do_sum)
                     if idx_chunk > 0:
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -887,7 +892,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         all_rank_num_tokens=all_rank_num_tokens_list[
                             idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call))
+                        repeating_info=(is_first_call, is_last_call),
+                        alltoall_result_do_sum=alltoall_result_do_sum)
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -901,7 +907,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     router_logits,
                     use_all_to_all,
                     all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
-                    repeating_info=(is_first_call, is_last_call))
+                    repeating_info=(is_first_call, is_last_call),
+                    alltoall_result_do_sum=alltoall_result_do_sum)

             outputs_list.append(outputs)
         if not use_all_to_all:
@@ -957,7 +964,8 @@ def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
         return x, x_sf, token_selected_slots, token_final_scales

     def alltoall_combine(self, final_hidden_states: torch.Tensor,
-                         alltoall_info: MoEAlltoallInfo, token_count: int):
+                         alltoall_info: MoEAlltoallInfo, token_count: int,
+                         alltoall_result_do_sum: bool):
         top_k = self.routing_method.experts_per_token
         if isinstance(final_hidden_states, list):
             final_hidden_states = final_hidden_states[0]
@@ -970,7 +978,7 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,
             top_k=top_k,
             token_count=token_count,
             use_low_precision_combine=self.use_low_precision_combine,
-            do_reduce=False)
+            do_reduce=alltoall_result_do_sum)

         return final_hidden_states
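Read together, these hunks thread one boolean end to end: `forward_impl` forwards `alltoall_result_do_sum` into every `forward_chunk` call (chunked or not), `forward_chunk` forces it back to `True` whenever the MNNVL all-to-all path is not taken, and `alltoall_combine` maps it onto the combine kernel's `do_reduce` switch, which this path previously hard-coded to `False`. A condensed sketch of that control flow, with illustrative stand-in bodies rather than the real implementation:

    from enum import Enum, auto

    class AlltoallMethodType(Enum):
        MNNVL = auto()
        DeepEP = auto()

    class WideEPSketch:
        """Illustrative stand-in for how the flag is threaded and clamped."""

        def __init__(self, method):
            self.alltoall_method_type = method

        def forward_impl(self, chunks, alltoall_result_do_sum=True):
            # The entry point forwards the flag to every chunk it launches.
            return [
                self.forward_chunk(c, alltoall_result_do_sum=alltoall_result_do_sum)
                for c in chunks
            ]

        def forward_chunk(self, x, use_all_to_all=True, alltoall_result_do_sum=True):
            # Outside the MNNVL path the combine must always reduce,
            # so the caller's choice is overridden.
            if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
                alltoall_result_do_sum = True
            return self.alltoall_combine(x, alltoall_result_do_sum)

        def alltoall_combine(self, x, alltoall_result_do_sum):
            # The flag lands on the kernel's reduce switch (do_reduce).
            do_reduce = alltoall_result_do_sum
            return sum(x) if do_reduce else x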

tensorrt_llm/_torch/modules/fused_moe/interface.py

Lines changed: 2 additions & 0 deletions
@@ -242,6 +242,7 @@ def forward(
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
+        **kwargs,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
         if self.register_to_config and is_torch_compiling():
             hidden_states = x.fp4_tensor if isinstance(
@@ -274,6 +275,7 @@ def forward(
             output_dtype=output_dtype,
             all_rank_num_tokens=all_rank_num_tokens,
             use_dp_padding=use_dp_padding,
+            **kwargs,
         )

     @property
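The `**kwargs` passthrough means the shared MoE interface does not have to enumerate backend-specific options such as `alltoall_result_do_sum`; the base `forward` simply relays whatever extra keywords the caller supplies. A minimal sketch of the pattern, with hypothetical class names:

    from typing import Any

    class MoEBase:
        """Hypothetical interface: unknown keyword options pass through untouched."""

        def forward(self, x: Any, **kwargs: Any) -> Any:
            # Interface-level work (registration, compile hooks) happens here;
            # backend-specific flags are relayed verbatim to the implementation.
            return self.forward_impl(x, **kwargs)

        def forward_impl(self, x: Any, **kwargs: Any) -> Any:
            raise NotImplementedError

    class WideEP(MoEBase):
        def forward_impl(self, x: Any, alltoall_result_do_sum: bool = True,
                         **kwargs: Any) -> Any:
            # Only this backend interprets the flag; others never see it.
            return x if alltoall_result_do_sum else [x]

The trade-off of this design is weaker interface-level validation: a misspelled keyword travels silently until the concrete `forward_impl` rejects it.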

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ l0_dgx_b200:
   tests:
   - unittest/_torch/multi_gpu_modeling -k "deepseek"
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
+  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ l0_dgx_h100:
   - unittest/_torch/multi_gpu_modeling/test_deepseek.py::test_deepseek_streaming[tp4-bf16-trtllm-deepseekv3_lite]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency]
+  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[MNNVL]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype0]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype1]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype0]
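Both new test-db entries select a single parametrization by its bracketed id, which comes from the `ids=lambda s: s.name` argument visible in the test diff below. A small sketch of how pytest derives those ids (the enum members mirror the method types named in these lists; the test body is illustrative):

    import enum
    import pytest

    class AlltoallMethodType(enum.Enum):
        MNNVL = enum.auto()
        DeepEP = enum.auto()
        DeepEPLowLatency = enum.auto()

    # ids=lambda s: s.name turns each enum member into a bracketed suffix,
    # producing node ids like test_alltoall_ids[MNNVL] — the same form the
    # YAML entries above use to pick one parametrization.
    @pytest.mark.parametrize("alltoall_method_type",
                             list(AlltoallMethodType),
                             ids=lambda s: s.name)
    def test_alltoall_ids(alltoall_method_type):
        assert isinstance(alltoall_method_type, AlltoallMethodType)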

tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 9 additions & 4 deletions
@@ -213,11 +213,14 @@ def per_rank_test_fused_moe_alltoall(job_id):
     weights = {}
     for expert_id in range(NUM_EXPERTS):
         w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         w3_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         torch.nn.init.xavier_uniform_(w1_weight)
         torch.nn.init.xavier_uniform_(w2_weight)
         torch.nn.init.xavier_uniform_(w3_weight)
@@ -293,7 +296,6 @@ def per_rank_test_fused_moe_alltoall(job_id):
     assert r is None


-@pytest.mark.skip(reason="https://nvbugs/5467531")
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="needs 4 GPUs to run this test")
 @pytest.mark.parametrize("alltoall_method_type", [
@@ -303,6 +305,9 @@ def per_rank_test_fused_moe_alltoall(job_id):
                   ids=lambda s: s.name)
 def test_fused_moe_alltoall_fp4(alltoall_method_type):

+    if alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+        pytest.skip("Skipped due to https://nvbugs/5467531")
+
     world_size = 4
     dtype = torch.bfloat16
     HIDDEN_SIZE = 2560
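The net effect of the last two hunks: the blanket `@pytest.mark.skip`, which disabled the FP4 test for every backend, is replaced by a runtime skip that fires only for the parametrization still tracked by the bug, so the new MNNVL entries can actually run. A minimal sketch of this skip-narrowing pattern (enum and test body are illustrative):

    import enum
    import pytest

    class AlltoallMethodType(enum.Enum):
        MNNVL = enum.auto()
        DeepEPLowLatency = enum.auto()

    @pytest.mark.parametrize("method",
                             list(AlltoallMethodType),
                             ids=lambda s: s.name)
    def test_skip_narrowing(method):
        # Skip only the parametrization that is still broken, instead of
        # decorating the whole test and losing coverage for every backend.
        if method == AlltoallMethodType.DeepEPLowLatency:
            pytest.skip("Skipped due to https://nvbugs/5467531")
        assert method is AlltoallMethodType.MNNVL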
