
Commit 9625471

kaiyux and zongfeijing committed

[None][test] Add MNNVL AlltoAll tests to pre-merge (#7465)

Signed-off-by: Kaiyu Xie <[email protected]>
Signed-off-by: Zongfei Jing <[email protected]>
Co-authored-by: Zongfei Jing <[email protected]>
1 parent 95eac2c commit 9625471

File tree

5 files changed (+39 / -20 lines)

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,7 @@
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod,
                                  MoEWeightLoadingMode, create_moe)
+from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -849,6 +850,9 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
             output_dtype=hidden_states.dtype,
             all_rank_num_tokens=all_rank_num_tokens,
             use_dp_padding=use_dp_padding,
+            **({
+                "alltoall_result_do_sum": False
+            } if isinstance(self.experts, WideEPMoE) else {}),
         )

         return routed_output
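
For context, a minimal sketch of the conditional-kwarg splicing used in compute_routed_output above: the extra keyword is forwarded only when the experts module actually accepts it. The BasicExperts, WideExperts, and call_experts names below are illustrative stand-ins, not part of this change.

# Minimal, self-contained sketch (hypothetical class names) of forwarding a
# keyword argument only to modules that accept it, via conditional **-unpacking.
class BasicExperts:
    def forward(self, x, **kwargs):
        return f"basic({x})"


class WideExperts(BasicExperts):
    def forward(self, x, alltoall_result_do_sum=True, **kwargs):
        return f"wide({x}, do_sum={alltoall_result_do_sum})"


def call_experts(experts, x):
    # Splice in the extra kwarg only for WideExperts, mirroring the
    # isinstance(self.experts, WideEPMoE) check in the diff above.
    return experts.forward(
        x,
        **({
            "alltoall_result_do_sum": False
        } if isinstance(experts, WideExperts) else {}),
    )


print(call_experts(BasicExperts(), 1))  # basic(1)
print(call_experts(WideExperts(), 1))   # wide(1, do_sum=False)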

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 24 additions & 16 deletions
@@ -419,14 +419,15 @@ def reducescatter_or_allreduce(
         return outputs

     def forward_chunk(
-            self,
-            x: Union[torch.Tensor, Fp4QuantizedTensor],
-            router_logits: torch.Tensor,
-            use_all_to_all: bool,
-            output_dtype: Optional[torch.dtype] = None,
-            all_rank_num_tokens: Optional[List[int]] = None,
-            use_dp_padding: Optional[bool] = None,
-            repeating_info: Tuple = (True, True),
+        self,
+        x: Union[torch.Tensor, Fp4QuantizedTensor],
+        router_logits: torch.Tensor,
+        use_all_to_all: bool,
+        output_dtype: Optional[torch.dtype] = None,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        use_dp_padding: Optional[bool] = None,
+        repeating_info: Tuple = (True, True),
+        alltoall_result_do_sum: bool = True,
     ) -> torch.Tensor:
         all_rank_max_num_tokens = max(all_rank_num_tokens)
         if isinstance(x, Fp4QuantizedTensor):
@@ -441,7 +442,7 @@ def forward_chunk(
             self.layer_load_balancer.start_wait_gpu_stage()

         if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
-            pass
+            alltoall_result_do_sum = True

         weight_dtype = self.w3_w1_weight.dtype

@@ -706,7 +707,8 @@ def forward_chunk(
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             final_hidden_states = self.alltoall_combine(
-                final_hidden_states, alltoall_info, token_count)
+                final_hidden_states, alltoall_info, token_count,
+                alltoall_result_do_sum)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
             final_hidden_states = self.unpad_tensors(
                 padded, final_hidden_states)
@@ -751,6 +753,7 @@ def forward_impl(
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
+        alltoall_result_do_sum: bool = True,
         **kwargs,
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
@@ -778,7 +781,8 @@ def forward_impl(
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding,
-                repeating_info=(is_first_call, is_last_call))
+                repeating_info=(is_first_call, is_last_call),
+                alltoall_result_do_sum=alltoall_result_do_sum)
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 use_all_to_all,
@@ -836,7 +840,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         all_rank_num_tokens=all_rank_num_tokens_list[
                             idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call))
+                        repeating_info=(is_first_call, is_last_call),
+                        alltoall_result_do_sum=alltoall_result_do_sum)
                     if idx_chunk > 0:
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -852,7 +857,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         all_rank_num_tokens=all_rank_num_tokens_list[
                             idx_chunk],
                         use_dp_padding=use_dp_padding,
-                        repeating_info=(is_first_call, is_last_call))
+                        repeating_info=(is_first_call, is_last_call),
+                        alltoall_result_do_sum=alltoall_result_do_sum)
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
@@ -866,7 +872,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     router_logits,
                     use_all_to_all,
                     all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
-                    repeating_info=(is_first_call, is_last_call))
+                    repeating_info=(is_first_call, is_last_call),
+                    alltoall_result_do_sum=alltoall_result_do_sum)

                 outputs_list.append(outputs)
         if not use_all_to_all:
@@ -922,7 +929,8 @@ def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
         return x, x_sf, token_selected_slots, token_final_scales

     def alltoall_combine(self, final_hidden_states: torch.Tensor,
-                         alltoall_info: MoEAlltoallInfo, token_count: int):
+                         alltoall_info: MoEAlltoallInfo, token_count: int,
+                         alltoall_result_do_sum: bool):
         top_k = self.routing_method.experts_per_token
         if isinstance(final_hidden_states, list):
             final_hidden_states = final_hidden_states[0]
@@ -935,7 +943,7 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,
             top_k=top_k,
             token_count=token_count,
             use_low_precision_combine=self.use_low_precision_combine,
-            do_reduce=False)
+            do_reduce=alltoall_result_do_sum)

         return final_hidden_states
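
For context, a minimal sketch of how the new alltoall_result_do_sum flag threads from forward_impl through forward_chunk into the combine step, with the guard forcing the default whenever the MNNVL AlltoAll path is not in use. Function names and bodies below are placeholders, not the real implementation.

# Illustrative sketch (placeholder bodies) of threading alltoall_result_do_sum
# through the call chain, defaulting to the always-reduce behaviour.
def alltoall_combine(hidden_states, do_reduce: bool) -> str:
    # The real combine kernel reduces per-token partial results only when asked to.
    return f"combine(do_reduce={do_reduce})"


def forward_chunk(hidden_states, use_mnnvl_alltoall: bool,
                  alltoall_result_do_sum: bool = True) -> str:
    if not use_mnnvl_alltoall:
        # Same guard as in the diff: outside the MNNVL path the flag is forced
        # back to True, so existing behaviour is unchanged.
        alltoall_result_do_sum = True
    return alltoall_combine(hidden_states, do_reduce=alltoall_result_do_sum)


def forward_impl(hidden_states, use_mnnvl_alltoall: bool,
                 alltoall_result_do_sum: bool = True) -> str:
    # Every chunk receives the same flag, so chunked and single-chunk paths agree.
    return forward_chunk(hidden_states, use_mnnvl_alltoall, alltoall_result_do_sum)


print(forward_impl("x", use_mnnvl_alltoall=True, alltoall_result_do_sum=False))
print(forward_impl("x", use_mnnvl_alltoall=False, alltoall_result_do_sum=False))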

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ l0_dgx_b200:
   tests:
   - unittest/_torch/multi_gpu_modeling -k "deepseek"
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
+  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ l0_dgx_h100:
   - unittest/_torch/multi_gpu_modeling/test_deepseek.py::test_deepseek_streaming[tp4-bf16-trtllm-deepseekv3_lite]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency]
+  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[MNNVL]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype0]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype1]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype0]

tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 9 additions & 4 deletions
@@ -212,11 +212,14 @@ def per_rank_test_fused_moe_alltoall(job_id):
     weights = {}
     for expert_id in range(NUM_EXPERTS):
         w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         w3_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         torch.nn.init.xavier_uniform_(w1_weight)
         torch.nn.init.xavier_uniform_(w2_weight)
         torch.nn.init.xavier_uniform_(w3_weight)
@@ -292,7 +295,6 @@ def per_rank_test_fused_moe_alltoall(job_id):
     assert r is None


-@pytest.mark.skip(reason="https://nvbugs/5467531")
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="needs 4 GPUs to run this test")
 @pytest.mark.parametrize("alltoall_method_type", [
@@ -302,6 +304,9 @@ def per_rank_test_fused_moe_alltoall(job_id):
     ids=lambda s: s.name)
 def test_fused_moe_alltoall_fp4(alltoall_method_type):

+    if alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+        pytest.skip("Skipped due to https://nvbugs/5467531")
+
     world_size = 4
     dtype = torch.bfloat16
     HIDDEN_SIZE = 2560
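
For context, a minimal sketch of the skip pattern adopted in the test change above: rather than a blanket @pytest.mark.skip on the whole test, only the affected parametrization is skipped at runtime, so the MNNVL case keeps running in CI. The enum and test body below are illustrative stand-ins for the real ones.

# Illustrative sketch: skip one parametrization at runtime instead of the whole test.
import enum

import pytest


class AlltoallMethodType(enum.Enum):
    MNNVL = "MNNVL"
    DeepEPLowLatency = "DeepEPLowLatency"


@pytest.mark.parametrize("alltoall_method_type",
                         list(AlltoallMethodType),
                         ids=lambda s: s.name)
def test_fused_moe_alltoall_fp4_sketch(alltoall_method_type):
    if alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
        # Only the known-broken backend is skipped; MNNVL still exercises the path.
        pytest.skip("Skipped due to https://nvbugs/5467531")
    assert alltoall_method_type == AlltoallMethodType.MNNVL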
