@@ -192,8 +192,12 @@ def __init__(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(hidden_size, dtype)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-            self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
-                                              self.moe_max_num_tokens)
+            self.deep_ep_max_num_tokens = int(
+                os.environ.get(
+                    "TRTLLM_DEEP_EP_TOKEN_LIMIT",
+                    str(
+                        min(model_config.max_num_tokens,
+                            self.moe_max_num_tokens))))
             self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
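For context, the added lines let the `TRTLLM_DEEP_EP_TOKEN_LIMIT` environment variable override the default DeepEP low-latency token cap. Below is a minimal standalone sketch of how that resolution behaves, assuming placeholder values for `model_config.max_num_tokens` and `self.moe_max_num_tokens`; it is not the module itself.

```python
import os

# Standalone sketch: resolve the DeepEP low-latency token cap the way the
# hunk above does, falling back to the smaller of the two configured limits
# when TRTLLM_DEEP_EP_TOKEN_LIMIT is unset. Numeric values are placeholders.
max_num_tokens = 8192      # stands in for model_config.max_num_tokens
moe_max_num_tokens = 4096  # stands in for self.moe_max_num_tokens

deep_ep_max_num_tokens = int(
    os.environ.get("TRTLLM_DEEP_EP_TOKEN_LIMIT",
                   str(min(max_num_tokens, moe_max_num_tokens))))
print(deep_ep_max_num_tokens)  # 4096 unless the env var overrides it
```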
@@ -274,6 +278,25 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
+    def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
+        num_rows = sum(all_rank_num_tokens)
+        return (num_rows + self.moe_max_num_tokens -
+                1) // self.moe_max_num_tokens
+
+    def can_use_alltoall(self, input, all_rank_num_tokens):
+        # Disable alltoall when chunking is used
+        if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            return False
+
+        num_tokens = input.shape[0]
+
+        # For DeepEPLowLatency, check if tokens exceed the threshold
+        if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
+                and num_tokens > self.deep_ep_max_num_tokens):
+            return False
+
+        return self.enable_alltoall
+
     def _get_quant_method(self):
         if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                 exclude_kv_cache=True):
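To illustrate the gating these two helpers introduce, here is a small self-contained sketch with hypothetical limits; `moe_max_num_tokens`, `deep_ep_max_num_tokens`, and the constant `True` standing in for `self.enable_alltoall` are assumptions for the example, not the real configuration.

```python
# Standalone sketch of the chunking and alltoall gating added above.
moe_max_num_tokens = 4096
deep_ep_max_num_tokens = 256

def calculate_num_chunks(all_rank_num_tokens):
    num_rows = sum(all_rank_num_tokens)
    return (num_rows + moe_max_num_tokens - 1) // moe_max_num_tokens

def can_use_alltoall(num_tokens, all_rank_num_tokens, low_latency=True):
    # Chunked execution falls back to the allgather/reduce-scatter path.
    if calculate_num_chunks(all_rank_num_tokens) > 1:
        return False
    # DeepEPLowLatency additionally caps the per-rank token count.
    if low_latency and num_tokens > deep_ep_max_num_tokens:
        return False
    return True  # stands in for self.enable_alltoall

print(calculate_num_chunks([1024, 1024, 1024, 1024]))  # 1 chunk
print(can_use_alltoall(300, [300, 200, 100, 100]))     # False: 300 > 256
print(can_use_alltoall(200, [200, 200, 200, 200]))     # True
```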
@@ -316,11 +339,12 @@ def dummy_allreduce(self):
     def reducescatter_or_allreduce(
         self,
         inputs,
+        use_all_to_all: bool,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if not self.enable_alltoall:
+        if not use_all_to_all:
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             outputs = reducescatter(
@@ -334,6 +358,7 @@ def forward_chunk(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
+        use_all_to_all: bool,
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         all_rank_max_num_tokens: Optional[int] = None,
@@ -382,7 +407,7 @@ def forward_chunk(
         ) and is_first_call:
             self.layer_load_balancer.maybe_cudagraph_done_wait()
 
-        use_allgather = not self.enable_alltoall
+        use_allgather = not use_all_to_all
 
         loadbalancer_local_statistic_info = None
         gathered_loadbalancer_local_statistic_info = None
@@ -391,7 +416,7 @@ def forward_chunk(
             token_selected_slots = token_selected_experts
         else:
             if not self.layer_load_balancer.is_static_routing(
-            ) and self.enable_alltoall:
+            ) and use_all_to_all:
                 self.layer_load_balancer.local_statistic(
                     token_selected_experts,
                     is_first_stage=is_first_call,
@@ -400,7 +425,7 @@ def forward_chunk(
                 token_selected_experts, self.use_dp)
             if not self.layer_load_balancer.is_static_routing():
                 # split into two part to get possible overlap with load balancer routing
-                if self.enable_alltoall:
+                if use_all_to_all:
                     if is_last_call:
                         loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
                         )
@@ -412,7 +437,9 @@ def forward_chunk(
         ExpertStatistic.set_layer(self.layer_idx)
         ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
 
-        if self.enable_alltoall:
+        # If alltoall is disabled, we also need to disable use_postquant_alltoall
+        use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
@@ -423,15 +450,16 @@ def forward_chunk(
                     x,
                     token_selected_slots,
                     token_final_scales,
+                    use_postquant_alltoall,
                     loadbalancer_local_statistic_info)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                     x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                         self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
                     padded, x, _, recv_topk_idx, token_final_scales = self.pad_empty_recv_tensors(
                         x, None, recv_topk_idx, token_final_scales)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                     deep_ep_topk_idx = token_selected_slots.to(torch.int64)
                     deep_ep_topk_weights = token_final_scales
                     x, recv_expert_count, deep_ep_handle = \
@@ -471,7 +499,7 @@ def forward_chunk(
                 x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, self.fc31_input_dequant)
             elif self.has_nvfp4:
-                if use_allgather or self.use_postquant_alltoall:
+                if use_allgather or use_postquant_alltoall:
                     if isinstance(x, Fp4QuantizedTensor):
                         if use_allgather:
                             assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"
@@ -527,7 +555,7 @@ def forward_chunk(
 
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ):
-            if self.enable_alltoall:
+            if use_all_to_all:
                 if is_last_call:
                     gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
                         (self.mapping.moe_ep_size, self.num_experts))
@@ -547,7 +575,7 @@ def forward_chunk(
         cluster_rank = self.cluster_rank
         quant_scales = self.quant_scales
 
-        if self.use_postquant_alltoall:
+        if use_postquant_alltoall:
             if x_sf is not None and self.has_nvfp4:
                 assert not x_is_sf_swizzled, "Fp4 scaling factor should not be swizzled before Alltoall"
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -640,7 +668,7 @@ def forward_chunk(
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
             )
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
             # TODO: remove the adapter by changing APIs
             if self.alltoall_method_type == AlltoallMethodType.DeepEP:
@@ -666,7 +694,7 @@ def forward_chunk(
             ep_rank=ep_rank,
             cluster_size=cluster_size,
             cluster_rank=cluster_rank,
-            enable_alltoall=self.enable_alltoall,
+            enable_alltoall=use_all_to_all,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
             use_w4a8_group_scaling=use_w4a8_group_scaling,
             min_latency_mode=False,
@@ -681,7 +709,7 @@ def forward_chunk(
             # Otherwise, the output should be unpacked as a single tensor.
             final_hidden_states = final_hidden_states[0]
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
@@ -737,11 +765,10 @@ def forward(
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
         assert use_dp_padding is not None
-        num_rows = sum(all_rank_num_tokens)
 
         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
-        num_chunks = (num_rows + self.moe_max_num_tokens -
-                      1) // self.moe_max_num_tokens
+        num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)
 
         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
@@ -754,13 +781,15 @@ def forward(
             outputs = self.forward_chunk(
                 x,
                 router_logits,
+                use_all_to_all,
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 all_rank_max_num_tokens=all_rank_max_num_tokens,
                 use_dp_padding=use_dp_padding,
                 repeating_info=(is_first_call, is_last_call))
             outputs = self.reducescatter_or_allreduce(
                 outputs,
+                use_all_to_all,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding)
         else:
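The single-chunk branch above shows the intent of the change: `use_all_to_all` is computed once in `forward()` and threaded into every downstream call instead of each callee re-reading `self.enable_alltoall`. A toy, self-contained sketch of that pattern follows; the helpers are trivial stand-ins, not the real TensorRT-LLM functions.

```python
from typing import List

# Stand-in helpers: the real code dispatches via MNNVL/DeepEP when
# use_all_to_all is True, and only reduces when it is False.
def forward_chunk(x: List[float], use_all_to_all: bool) -> List[float]:
    return x

def reducescatter_or_allreduce(x: List[float],
                               use_all_to_all: bool) -> List[float]:
    return x if use_all_to_all else [sum(x)]

def forward(x: List[float], use_all_to_all: bool) -> List[float]:
    # One decision made by the caller is passed to every stage, so the
    # alltoall choice cannot drift between forward_chunk() and the
    # subsequent reduction.
    out = forward_chunk(x, use_all_to_all)
    return reducescatter_or_allreduce(out, use_all_to_all)

print(forward([1.0, 2.0], use_all_to_all=False))  # [3.0]
print(forward([1.0, 2.0], use_all_to_all=True))   # [1.0, 2.0]
```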
@@ -782,7 +811,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
                                                        num_chunks)
             chunk_size_list = all_rank_chunk_size_list[self.rank]
-            if self.enable_alltoall:
+            if use_all_to_all:
                 all_rank_num_tokens_list = [[
                     1 if val == 0 else val for val in val_list
                 ] for val_list in all_rank_num_tokens_list]
@@ -794,7 +823,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             x_list = x.split(chunk_size_list)
             router_logits_list = router_logits.split(chunk_size_list)
 
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 self.event_dict[EventType.Main].record()
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.Main].wait()
@@ -805,12 +834,13 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     zip(x_list, router_logits_list)):
                 is_first_call = idx_chunk == 0 and self.repeat_idx == 0
                 is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1
-                if not self.enable_alltoall:
+                if not use_all_to_all:
                     if idx_chunk % 2 == 0:
                         with torch.cuda.stream(self.aux_stream):
                             outputs = self.forward_chunk(
                                 x,
                                 router_logits,
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk],
                                 all_rank_max_num_tokens=
@@ -820,13 +850,15 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         if idx_chunk > 0:
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                     else:
                         outputs = self.forward_chunk(
                             x,
                             router_logits,
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[
                                 idx_chunk],
                             all_rank_max_num_tokens=all_rank_max_num_tokens_list[
@@ -836,29 +868,33 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         with torch.cuda.stream(self.aux_stream):
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
                 else:
                     outputs = self.forward_chunk(
                         x,
                         router_logits,
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
                         all_rank_max_num_tokens=all_rank_max_num_tokens_list[
                             idx_chunk],
                         repeating_info=(is_first_call, is_last_call))
 
                 outputs_list.append(outputs)
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 if num_chunks % 2 == 0:
                     outputs_list[-1] = self.reducescatter_or_allreduce(
                         outputs_list[-1],
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[-1],
                         use_dp_padding=use_dp_padding)
                 else:
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[-1],
                             use_dp_padding=use_dp_padding)
             with torch.cuda.stream(self.aux_stream):
@@ -873,7 +909,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
     def alltoall_prepare_maybe_dispatch(
             self, all_rank_max_num_tokens: int, x: torch.Tensor,
             token_selected_slots: torch.Tensor,
-            token_final_scales: torch.Tensor,
+            token_final_scales: torch.Tensor, use_postquant_alltoall: bool,
             local_statistic_tensor: Optional[torch.Tensor]):
         top_k = self.routing_method.experts_per_token
 
@@ -919,7 +955,7 @@ def alltoall_prepare_maybe_dispatch(
             gathered_token_final_scales, all_rank_max_num_tokens,
             self.num_slots, top_k, self.ep_rank, self.ep_size)
 
-        if not self.use_postquant_alltoall:
+        if not use_postquant_alltoall:
             assert not isinstance(
                 x, Fp4QuantizedTensor
             ), "pre-quant alltoall doesn't support fp4 tensor"