@@ -334,7 +334,7 @@ def reducescatter_or_allreduce(
         use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if not use_all_toall:
+        if not use_all_to_all:
            if self.enable_dummy_allreduce:
                self.dummy_allreduce()
            outputs = reducescatter(
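
Note on the one-word fix above: the removed line referenced `use_all_toall`, which does not match any name in scope, so the `not use_all_toall` branch would presumably have raised a `NameError` at runtime. A minimal sketch of the control flow this gate implements, with `reducescatter_fn` as a stand-in for the real `reducescatter` collective:

```python
# Hedged sketch, not the TensorRT-LLM implementation: when tokens were
# dispatched with all-to-all, the all-to-all combine already returns each
# rank's own tokens, so the reduce-scatter/all-reduce here must be skipped.
def combine(outputs, use_all_to_all: bool, reducescatter_fn):
    if not use_all_to_all:
        # allgather / attention-DP path: partial outputs still need a collective
        outputs = reducescatter_fn(outputs)
    return outputs

assert combine([1, 2], True, sum) == [1, 2]   # all-to-all path: left untouched
assert combine([1, 2], False, sum) == 3       # fallback path: collective applied
```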
@@ -385,7 +385,7 @@ def forward_chunk(
         ) and is_first_call:
             self.layer_load_balancer.maybe_cudagraph_done_wait()

-        use_allgather = not self.enable_alltoall
+        use_allgather = not use_all_to_all

         loadbalancer_local_statistic_info = None
         gathered_loadbalancer_local_statistic_info = None
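
The switch from `self.enable_alltoall` to the `use_all_to_all` argument matters because the flag is now decided per call rather than once per module; the chunking hunk near line 803 below suggests the caller can choose a different value depending on runtime token counts. A hypothetical illustration of the bug class this fixes (names are illustrative, not the real API):

```python
# With the old code, use_allgather came from the static config, so a call that
# disabled all-to-all at runtime could skip *both* communication paths.
def pick_path_old(enable_alltoall_cfg: bool, use_all_to_all: bool):
    use_allgather = not enable_alltoall_cfg        # old: static module config
    return [p for p, on in (("alltoall", use_all_to_all),
                            ("allgather", use_allgather)) if on]

def pick_path_new(use_all_to_all: bool):
    use_allgather = not use_all_to_all             # new: per-call flag
    return [p for p, on in (("alltoall", use_all_to_all),
                            ("allgather", use_allgather)) if on]

assert pick_path_old(True, False) == []            # no path taken: silent bug
assert pick_path_new(False) == ["allgather"]       # exactly one path taken
```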
@@ -394,7 +394,7 @@ def forward_chunk(
             token_selected_slots = token_selected_experts
         else:
             if not self.layer_load_balancer.is_static_routing(
-            ) and self.enable_alltoall:
+            ) and use_all_to_all:
                self.layer_load_balancer.local_statistic(
                    token_selected_experts,
                    is_first_stage=is_first_call,
@@ -403,7 +403,7 @@ def forward_chunk(
                 token_selected_experts, self.use_dp)
             if not self.layer_load_balancer.is_static_routing():
                 # split into two part to get possible overlap with load balancer routing
-                if self.enable_alltoall:
+                if use_all_to_all:
                    if is_last_call:
                        loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
                        )
@@ -415,6 +415,8 @@ def forward_chunk(
         ExpertStatistic.set_layer(self.layer_idx)
         ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)

+        # If alltoall is disabled, we also need to disable use_postquant_alltoall
+        use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
         if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
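
The two added lines derive a call-local `use_postquant_alltoall`; the remaining hunks then read the derived flag, so the post-quant branches (e.g. near line 550 below) cannot fire when all-to-all dispatch was skipped. The logic is a plain conjunction:

```python
# Truth-table check of the derived flag: post-quant all-to-all is honored only
# when all-to-all dispatch actually runs on this call.
def effective_postquant(config_postquant: bool, use_all_to_all: bool) -> bool:
    return config_postquant and use_all_to_all

assert effective_postquant(True, True) is True
assert effective_postquant(True, False) is False   # alltoall off => postquant off
assert effective_postquant(False, True) is False   # config off stays off
```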
@@ -426,13 +428,14 @@ def forward_chunk(
                     x,
                     token_selected_slots,
                     token_final_scales,
+                    use_postquant_alltoall,
                     loadbalancer_local_statistic_info)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                    x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                        self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                    deep_ep_topk_idx = token_selected_slots.to(torch.int64)
                    deep_ep_topk_weights = token_final_scales
                    x, recv_expert_count, deep_ep_handle = \
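
For the DeepEP variants, `use_postquant_alltoall` decides whether quantization happens before or after the dispatch. A hedged sketch of the two orderings with stand-in callables (not the DeepEP API):

```python
# Pre-quant all-to-all: communicate in the original dtype, quantize on arrival.
# Post-quant all-to-all: quantize first, so fewer bytes cross the network.
def dispatch_tokens(x, use_postquant_alltoall: bool, quantize, dispatch):
    if use_postquant_alltoall:
        return dispatch(quantize(x))   # quantize, then communicate
    return quantize(dispatch(x))       # communicate, then quantize

# Dummy usage: halving stands in for quantization, reversal for dispatch.
out = dispatch_tokens([2.0, 4.0], True,
                      lambda t: [v / 2 for v in t],
                      lambda t: t[::-1])
assert out == [2.0, 1.0]
```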
@@ -471,7 +474,7 @@ def forward_chunk(
                 x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, self.fc31_input_dequant)
             elif self.has_nvfp4:
-                if use_allgather or self.use_postquant_alltoall:
+                if use_allgather or use_postquant_alltoall:
                    if isinstance(x, Fp4QuantizedTensor):
                        if use_allgather:
                            assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"
@@ -504,7 +507,7 @@ def forward_chunk(
                     f"unsupported quantization mode: {self.quant_config.quant_mode}"
                 )

-        if use_allgather and not use_all_to_all::
+        if use_allgather and not use_all_to_all:
             # using allgather case.
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
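
The removed line here carried a stray second colon, i.e. a syntax error, so this hunk is a pure repair; note also that with `use_allgather = not use_all_to_all` set earlier, the `and not use_all_to_all` clause is redundant but harmless. For orientation, a heavily simplified sketch of the allgather fallback (stand-in collective, not the real `allgather` op):

```python
# When all-to-all is off, every EP rank gathers all attention-DP ranks' tokens
# and routing metadata, runs its local experts over the full batch, and relies
# on the reduce-scatter from the first hunk to hand back each rank's share.
def allgather_path(x, token_selected_slots, token_final_scales, allgather_fn):
    x = allgather_fn(x)                                        # replicate tokens
    token_selected_slots = allgather_fn(token_selected_slots)  # routing travels too
    token_final_scales = allgather_fn(token_final_scales)
    return x, token_selected_slots, token_final_scales
```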
@@ -527,7 +530,7 @@ def forward_chunk(

         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ):
-            if self.enable_alltoall:
+            if use_all_to_all:
                if is_last_call:
                    gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
                        (self.mapping.moe_ep_size, self.num_experts))
@@ -547,7 +550,7 @@ def forward_chunk(
             cluster_rank = self.cluster_rank
             quant_scales = self.quant_scales

-        if self.use_postquant_alltoall:
+        if use_postquant_alltoall:
             if x_sf is not None and self.has_nvfp4:
                 assert not x_is_sf_swizzled, "Fp4 scaling factor should not be swizzled before Alltoall"
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
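
On the post-quant path an NVFP4 activation travels as a payload plus per-block scale factors, and the assertion above guards their layout. An illustrative sketch of the invariant (simplified; the real MNNVL and DeepEP paths differ in detail):

```python
# The all-to-all permutes token rows across ranks, so the scale factors must be
# in row-linear (unswizzled) layout to stay aligned with their rows.
def postquant_alltoall(x, x_sf, x_is_sf_swizzled: bool, alltoall_fn):
    if x_sf is not None:
        assert not x_is_sf_swizzled, \
            "Fp4 scaling factor should not be swizzled before Alltoall"
        x_sf = alltoall_fn(x_sf)   # scales move with their payload rows
    x = alltoall_fn(x)
    return x, x_sf
```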
@@ -800,7 +803,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
                                                        num_chunks)
             chunk_size_list = all_rank_chunk_size_list[self.rank]
-            if self.enable_alltoall:
+            if use_all_to_all:
                all_rank_num_tokens_list = [[
                    1 if val == 0 else val for val in val_list
                ] for val_list in all_rank_num_tokens_list]
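
The chunking hunk clamps zero entries in the per-rank token-count lists to one on the all-to-all path, presumably because the dispatch kernels expect at least one (padded) token per rank. The comprehension in isolation:

```python
# Standalone run of the clamp from the hunk above.
all_rank_num_tokens_list = [[0, 5, 2], [3, 0, 4]]
all_rank_num_tokens_list = [[1 if val == 0 else val for val in val_list]
                            for val_list in all_rank_num_tokens_list]
assert all_rank_num_tokens_list == [[1, 5, 2], [3, 1, 4]]
```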
@@ -898,7 +901,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
     def alltoall_prepare_maybe_dispatch(
             self, all_rank_max_num_tokens: int, x: torch.Tensor,
             token_selected_slots: torch.Tensor,
-            token_final_scales: torch.Tensor,
+            token_final_scales: torch.Tensor, use_postquant_alltoall: bool,
             local_statistic_tensor: Optional[torch.Tensor]):
         top_k = self.routing_method.experts_per_token

@@ -942,7 +945,7 @@ def alltoall_prepare_maybe_dispatch(
             gathered_token_final_scales, all_rank_max_num_tokens,
             self.num_slots, top_k, self.ep_rank, self.ep_size)

-        if not self.use_postquant_alltoall:
+        if not use_postquant_alltoall:
            assert not isinstance(
                x, Fp4QuantizedTensor
            ), "pre-quant alltoall doesn't support fp4 tensor"