@@ -335,7 +335,7 @@ def reducescatter_or_allreduce(
         use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if not use_all_toall:
+        if not use_all_to_all:
            if self.enable_dummy_allreduce:
                self.dummy_allreduce()
            outputs = reducescatter(
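
The misspelled `use_all_toall` would only surface at runtime, as a `NameError`, and only on the branch that skips all-to-all. Below is a minimal single-process sketch of the intended gating; the function name and the sum-as-collective are stand-ins, not the real TensorRT-LLM API.

```python
# Minimal sketch of the flag-gated combine step; `reducescatter` is mimicked
# by a local sum, standing in for the real distributed collective.
from typing import List, Optional


def reducescatter_or_allreduce_sketch(
        inputs: List[float],
        use_all_to_all: bool,
        use_dp_padding: Optional[bool] = None) -> List[float]:
    outputs = inputs
    if not use_all_to_all:
        # Without all-to-all, expert outputs are still replicated across
        # ranks and must be combined; a sum stands in for the collective.
        outputs = [sum(inputs)]
    return outputs


# All-to-all already routed tokens back to their owners: pass through.
assert reducescatter_or_allreduce_sketch([1.0, 2.0], True) == [1.0, 2.0]
# Otherwise the combine collective runs.
assert reducescatter_or_allreduce_sketch([1.0, 2.0], False) == [3.0]
```
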
@@ -398,7 +398,7 @@ def forward_chunk(
         ) and is_first_call:
             self.layer_load_balancer.maybe_cudagraph_done_wait()
 
-        use_allgather = not self.enable_alltoall
+        use_allgather = not use_all_to_all
 
         loadbalancer_local_statistic_info = None
         gathered_loadbalancer_local_statistic_info = None
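
This is the same class of bug throughout the patch: the per-call `use_all_to_all` decision can differ from the static `self.enable_alltoall` (for instance when a given chunk cannot take the all-to-all path), and `use_allgather` must be its exact complement. A tiny sketch of that invariant; `chunk_can_use_alltoall` is a hypothetical stand-in for whatever per-call condition the module applies.

```python
# Sketch: the allgather flag must mirror the per-call decision, not the
# static config; `chunk_can_use_alltoall` is hypothetical.
def pick_comm_mode(enable_alltoall: bool, chunk_can_use_alltoall: bool):
    use_all_to_all = enable_alltoall and chunk_can_use_alltoall
    use_allgather = not use_all_to_all
    return use_all_to_all, use_allgather


# Config enables all-to-all, but this call cannot use it; the old
# `not self.enable_alltoall` would have wrongly disabled allgather too.
assert pick_comm_mode(True, False) == (False, True)
```
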
@@ -407,7 +407,7 @@ def forward_chunk(
             token_selected_slots = token_selected_experts
         else:
             if not self.layer_load_balancer.is_static_routing(
-            ) and self.enable_alltoall:
+            ) and use_all_to_all:
                 self.layer_load_balancer.local_statistic(
                     token_selected_experts,
                     is_first_stage=is_first_call,
@@ -416,7 +416,7 @@ def forward_chunk(
                 token_selected_experts, self.use_dp)
             if not self.layer_load_balancer.is_static_routing():
                 # split into two parts to get possible overlap with load balancer routing
-                if self.enable_alltoall:
+                if use_all_to_all:
                     if is_last_call:
                         loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
                         )
@@ -428,6 +428,8 @@ def forward_chunk(
         ExpertStatistic.set_layer(self.layer_idx)
         ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
 
+        # If alltoall is disabled, we also need to disable use_postquant_alltoall
+        use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
         if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
@@ -439,13 +441,14 @@ def forward_chunk(
                     x,
                     token_selected_slots,
                     token_final_scales,
+                    use_postquant_alltoall,
                     loadbalancer_local_statistic_info)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                     x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                         self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                     deep_ep_topk_idx = token_selected_slots.to(torch.int64)
                     deep_ep_topk_weights = token_final_scales
                     x, recv_expert_count, deep_ep_handle = \
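
The core of these two hunks is deriving an effective per-call flag, `use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all`, and threading it through the dispatch branches instead of reading the instance attribute. A hedged sketch of the pattern, with illustrative names rather than the real module:

```python
# Sketch: derive the effective post-quant flag per call, so a chunk that
# is not using all-to-all can never take the post-quant dispatch path.
class MoeChunkSketch:

    def __init__(self, use_postquant_alltoall: bool):
        self.use_postquant_alltoall = use_postquant_alltoall  # static config

    def forward_chunk(self, use_all_to_all: bool) -> str:
        # If alltoall is disabled, also disable use_postquant_alltoall.
        use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
        if use_all_to_all:
            return "postquant" if use_postquant_alltoall else "prequant"
        return "allgather"


moe = MoeChunkSketch(use_postquant_alltoall=True)
assert moe.forward_chunk(use_all_to_all=True) == "postquant"
# The bug this diff fixes: with all-to-all off for this call, the raw
# attribute would still read True and could steer later branches wrongly.
assert moe.forward_chunk(use_all_to_all=False) == "allgather"
```
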
@@ -485,7 +488,7 @@ def forward_chunk(
                 x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, self.fc31_input_dequant)
             elif self.has_nvfp4:
-                if use_allgather or self.use_postquant_alltoall:
+                if use_allgather or use_postquant_alltoall:
                     if isinstance(x, Fp4QuantizedTensor):
                         if use_allgather:
                             assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"
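
For NVFP4, the input is quantized ahead of communication exactly when it will travel in quantized form, on either the allgather path or the post-quant all-to-all path; with the derived flag, a disabled all-to-all can no longer trigger this branch spuriously. A one-function sketch of the decision:

```python
# Sketch: quantize before communication iff the tensor is sent quantized.
def quantize_before_comm(use_allgather: bool,
                         use_postquant_alltoall: bool) -> bool:
    return use_allgather or use_postquant_alltoall


assert quantize_before_comm(True, False)       # allgather sends quantized data
assert quantize_before_comm(False, True)       # post-quant all-to-all does too
assert not quantize_before_comm(False, False)  # pre-quant dispatch: keep raw
```
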
@@ -518,7 +521,7 @@ def forward_chunk(
                 f"unsupported quantization mode: {self.quant_config.quant_mode}"
             )
 
-        if use_allgather and not use_all_to_all::
+        if use_allgather and not use_all_to_all:
             # using allgather case.
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
@@ -541,7 +544,7 @@ def forward_chunk(
 
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ):
-            if self.enable_alltoall:
+            if use_all_to_all:
                 if is_last_call:
                     gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
                         (self.mapping.moe_ep_size, self.num_experts))
@@ -561,7 +564,7 @@ def forward_chunk(
         cluster_rank = self.cluster_rank
         quant_scales = self.quant_scales
 
-        if self.use_postquant_alltoall:
+        if use_postquant_alltoall:
             if x_sf is not None and self.has_nvfp4:
                 assert not x_is_sf_swizzled, "Fp4 scaling factor should not be swizzled before Alltoall"
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -814,7 +817,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
                                                        num_chunks)
             chunk_size_list = all_rank_chunk_size_list[self.rank]
-            if self.enable_alltoall:
+            if use_all_to_all:
                 all_rank_num_tokens_list = [[
                     1 if val == 0 else val for val in val_list
                 ] for val_list in all_rank_num_tokens_list]
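
The surviving context shows why the per-call flag matters in the chunking loop: on the all-to-all path, ranks reporting zero tokens appear to be padded to one so the dispatch kernels see a nonzero count for every rank. A small runnable sketch of that padding, following my reading of the context lines:

```python
# Sketch: pad zero per-rank token counts to 1 on the all-to-all path.
all_rank_num_tokens_list = [[0, 5, 3], [2, 0, 0]]
all_rank_num_tokens_list = [[1 if val == 0 else val for val in val_list]
                            for val_list in all_rank_num_tokens_list]
assert all_rank_num_tokens_list == [[1, 5, 3], [2, 1, 1]]
```
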
@@ -912,7 +915,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
     def alltoall_prepare_maybe_dispatch(
             self, all_rank_max_num_tokens: int, x: torch.Tensor,
             token_selected_slots: torch.Tensor,
-            token_final_scales: torch.Tensor,
+            token_final_scales: torch.Tensor, use_postquant_alltoall: bool,
             local_statistic_tensor: Optional[torch.Tensor]):
         top_k = self.routing_method.experts_per_token
 
@@ -958,7 +961,7 @@ def alltoall_prepare_maybe_dispatch(
             gathered_token_final_scales, all_rank_max_num_tokens,
             self.num_slots, top_k, self.ep_rank, self.ep_size)
 
-        if not self.use_postquant_alltoall:
+        if not use_postquant_alltoall:
             assert not isinstance(
                 x, Fp4QuantizedTensor
             ), "pre-quant alltoall doesn't support fp4 tensor"