diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py
index 352a23893ca..68f3be8e7cc 100644
--- a/examples/llm-api/quickstart_advanced.py
+++ b/examples/llm-api/quickstart_advanced.py
@@ -50,7 +50,7 @@ def add_llm_args(parser):
     parser.add_argument('--moe_backend',
                         type=str,
                         default='CUTLASS',
-                        choices=['CUTLASS', 'TRTLLM', 'VANILLA'])
+                        choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
     parser.add_argument('--enable_attention_dp',
                         default=False,
                         action='store_true')
diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index 1834e7b1476..b92cef4dc54 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -62,8 +62,7 @@
 from ..modules.rms_norm import RMSNorm
 from ..peft.lora.layer import LoraLayer
 from ..speculative import MTPEagleWorker, MTPSpecMetadata, MTPWorker
-from ..utils import (AuxStreamType, EventType, Fp4QuantizedTensor,
-                     disable_fp4_allgather)
+from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                              EagerFusionConfig, filter_weights,
                              register_auto_model)
@@ -514,9 +513,7 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
         if self.use_dp and self.mapping.tp_size > 1:
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
-            if (disable_fp4_allgather()
-                    and not self.experts.enable_alltoall) or isinstance(
-                        self.experts, TRTLLMGenFusedMoE):
+            if isinstance(self.experts, TRTLLMGenFusedMoE):
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           dim=0,
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index 9290aae3029..f2821c8fa93 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -192,8 +192,12 @@ def __init__(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(hidden_size, dtype)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-            self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
-                                              self.moe_max_num_tokens)
+            self.deep_ep_max_num_tokens = int(
+                os.environ.get(
+                    "TRTLLM_DEEP_EP_TOKEN_LIMIT",
+                    str(
+                        min(model_config.max_num_tokens,
+                            self.moe_max_num_tokens))))
             self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
                 model_config.mapping)
             self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
@@ -274,6 +278,25 @@ def enable_alltoall(self):
         """
         return self.alltoall_method_type != AlltoallMethodType.NotEnabled
 
+    def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
+        num_rows = sum(all_rank_num_tokens)
+        return (num_rows + self.moe_max_num_tokens -
+                1) // self.moe_max_num_tokens
+
+    def can_use_alltoall(self, input, all_rank_num_tokens):
+        # Disable alltoall when chunking is used
+        if self.calculate_num_chunks(all_rank_num_tokens) > 1:
+            return False
+
+        num_tokens = input.shape[0]
+
+        # For DeepEPLowLatency, check if tokens exceed the threshold
+        if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
+                and num_tokens > self.deep_ep_max_num_tokens):
+            return False
+
+        return self.enable_alltoall
+
     def _get_quant_method(self):
         if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                 exclude_kv_cache=True):
@@ -316,11 +339,12 @@ def dummy_allreduce(self):
     def reducescatter_or_allreduce(
         self,
         inputs,
+        use_all_to_all: bool,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
     ):
         outputs = inputs
-        if not self.enable_alltoall:
+        if not use_all_to_all:
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             outputs = reducescatter(
@@ -334,6 +358,7 @@ def forward_chunk(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
         router_logits: torch.Tensor,
+        use_all_to_all: bool,
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         all_rank_max_num_tokens: Optional[int] = None,
@@ -382,7 +407,7 @@ def forward_chunk(
         ) and is_first_call:
             self.layer_load_balancer.maybe_cudagraph_done_wait()
 
-        use_allgather = not self.enable_alltoall
+        use_allgather = not use_all_to_all
 
         loadbalancer_local_statistic_info = None
         gathered_loadbalancer_local_statistic_info = None
@@ -391,7 +416,7 @@ def forward_chunk(
             token_selected_slots = token_selected_experts
         else:
             if not self.layer_load_balancer.is_static_routing(
-            ) and self.enable_alltoall:
+            ) and use_all_to_all:
                 self.layer_load_balancer.local_statistic(
                     token_selected_experts,
                     is_first_stage=is_first_call,
@@ -400,7 +425,7 @@ def forward_chunk(
                     token_selected_experts, self.use_dp)
             if not self.layer_load_balancer.is_static_routing():
                 # split into two part to get possible overlap with load balancer routing
-                if self.enable_alltoall:
+                if use_all_to_all:
                     if is_last_call:
                         loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
                         )
@@ -412,7 +437,9 @@ def forward_chunk(
         ExpertStatistic.set_layer(self.layer_idx)
         ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
 
-        if self.enable_alltoall:
+        # If alltoall is disabled, we need also disable use_postquant_alltoall
+        use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
@@ -423,13 +450,14 @@ def forward_chunk(
                         x,
                         token_selected_slots,
                         token_final_scales,
+                        use_postquant_alltoall,
                         loadbalancer_local_statistic_info)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                     x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
                         self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
             elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
-                if not self.use_postquant_alltoall:
+                if not use_postquant_alltoall:
                     deep_ep_topk_idx = token_selected_slots.to(torch.int64)
                     deep_ep_topk_weights = token_final_scales
                     x, recv_expert_count, deep_ep_handle = \
@@ -469,7 +497,7 @@ def forward_chunk(
                 x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                     x, self.fc31_input_dequant)
             elif self.has_nvfp4:
-                if use_allgather or self.use_postquant_alltoall:
+                if use_allgather or use_postquant_alltoall:
                     if isinstance(x, Fp4QuantizedTensor):
                         if use_allgather:
                             assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"
@@ -525,7 +553,7 @@ def forward_chunk(
 
         if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
         ):
-            if self.enable_alltoall:
+            if use_all_to_all:
                 if is_last_call:
                     gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
                         (self.mapping.moe_ep_size, self.num_experts))
@@ -545,7 +573,7 @@ def forward_chunk(
         cluster_rank = self.cluster_rank
         quant_scales = self.quant_scales
 
-        if self.use_postquant_alltoall:
+        if use_postquant_alltoall:
             if x_sf is not None and self.has_nvfp4:
                 assert not x_is_sf_swizzled, "Fp4 scaling factor should not be swizzled before Alltoall"
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -636,7 +664,7 @@ def forward_chunk(
                     f"Not available alltoall method type: {self.alltoall_method_type!r}"
                 )
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             # Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
             # TODO: remove the adapter by changing APIs
             if self.alltoall_method_type == AlltoallMethodType.DeepEP:
@@ -676,7 +704,7 @@ def forward_chunk(
             ep_rank=ep_rank,
             cluster_size=cluster_size,
             cluster_rank=cluster_rank,
-            enable_alltoall=self.enable_alltoall,
+            enable_alltoall=use_all_to_all,
             use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
             use_w4a8_group_scaling=use_w4a8_group_scaling,
             min_latency_mode=False,
@@ -691,7 +719,7 @@ def forward_chunk(
             # Otherwise, the output should be unpacked as a single tensor.
             final_hidden_states = final_hidden_states[0]
 
-        if self.enable_alltoall:
+        if use_all_to_all:
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 if self.enable_dummy_allreduce:
                     self.dummy_allreduce()
@@ -747,11 +775,10 @@ def forward(
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
         assert use_dp_padding is not None
-        num_rows = sum(all_rank_num_tokens)
 
         # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
-        num_chunks = (num_rows + self.moe_max_num_tokens -
-                      1) // self.moe_max_num_tokens
+        num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
+        use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)
 
         if use_dp_padding:
             all_rank_num_tokens_padded = [all_rank_max_num_tokens
@@ -764,6 +791,7 @@ def forward(
             outputs = self.forward_chunk(
                 x,
                 router_logits,
+                use_all_to_all,
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 all_rank_max_num_tokens=all_rank_max_num_tokens,
@@ -771,6 +799,7 @@ def forward(
                 repeating_info=(is_first_call, is_last_call))
             outputs = self.reducescatter_or_allreduce(
                 outputs,
+                use_all_to_all,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding)
         else:
@@ -792,7 +821,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
                                                        num_chunks)
             chunk_size_list = all_rank_chunk_size_list[self.rank]
-            if self.enable_alltoall:
+            if use_all_to_all:
                 all_rank_num_tokens_list = [[
                     1 if val == 0 else val for val in val_list
                 ] for val_list in all_rank_num_tokens_list]
@@ -804,7 +833,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             x_list = x.split(chunk_size_list)
             router_logits_list = router_logits.split(chunk_size_list)
 
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 self.event_dict[EventType.Main].record()
                 with torch.cuda.stream(self.aux_stream):
                     self.event_dict[EventType.Main].wait()
@@ -815,12 +844,13 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     zip(x_list, router_logits_list)):
                 is_first_call = idx_chunk == 0 and self.repeat_idx == 0
                 is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1
-                if not self.enable_alltoall:
+                if not use_all_to_all:
                     if idx_chunk % 2 == 0:
                         with torch.cuda.stream(self.aux_stream):
                             outputs = self.forward_chunk(
                                 x,
                                 router_logits,
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk],
                                 all_rank_max_num_tokens=
@@ -830,6 +860,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         if idx_chunk > 0:
                             outputs_list[-1] = self.reducescatter_or_allreduce(
                                 outputs_list[-1],
+                                use_all_to_all,
                                 all_rank_num_tokens=all_rank_num_tokens_list[
                                     idx_chunk - 1],
                                 use_dp_padding=use_dp_padding)
@@ -837,6 +868,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                         outputs = self.forward_chunk(
                             x,
                             router_logits,
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[
                                 idx_chunk],
                             all_rank_max_num_tokens=all_rank_max_num_tokens_list[
@@ -846,6 +878,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                             with torch.cuda.stream(self.aux_stream):
                                 outputs_list[-1] = self.reducescatter_or_allreduce(
                                     outputs_list[-1],
+                                    use_all_to_all,
                                     all_rank_num_tokens=all_rank_num_tokens_list[
                                         idx_chunk - 1],
                                     use_dp_padding=use_dp_padding)
@@ -853,22 +886,25 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                     outputs = self.forward_chunk(
                         x,
                         router_logits,
+                        use_all_to_all,
                        all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
                         all_rank_max_num_tokens=all_rank_max_num_tokens_list[
                             idx_chunk],
                         repeating_info=(is_first_call, is_last_call))
                 outputs_list.append(outputs)
 
-            if not self.enable_alltoall:
+            if not use_all_to_all:
                 if num_chunks % 2 == 0:
                     outputs_list[-1] = self.reducescatter_or_allreduce(
                         outputs_list[-1],
+                        use_all_to_all,
                         all_rank_num_tokens=all_rank_num_tokens_list[-1],
                         use_dp_padding=use_dp_padding)
                 else:
                     with torch.cuda.stream(self.aux_stream):
                         outputs_list[-1] = self.reducescatter_or_allreduce(
                             outputs_list[-1],
+                            use_all_to_all,
                             all_rank_num_tokens=all_rank_num_tokens_list[-1],
                             use_dp_padding=use_dp_padding)
             with torch.cuda.stream(self.aux_stream):
@@ -883,7 +919,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
 
     def alltoall_prepare_maybe_dispatch(
             self, all_rank_max_num_tokens: int, x: torch.Tensor,
             token_selected_slots: torch.Tensor,
-            token_final_scales: torch.Tensor,
+            token_final_scales: torch.Tensor, use_postquant_alltoall: bool,
             local_statistic_tensor: Optional[torch.Tensor]):
         top_k = self.routing_method.experts_per_token
@@ -929,7 +965,7 @@ def alltoall_prepare_maybe_dispatch(
             gathered_token_final_scales, all_rank_max_num_tokens,
             self.num_slots, top_k, self.ep_rank, self.ep_size)
 
-        if not self.use_postquant_alltoall:
+        if not use_postquant_alltoall:
             assert not isinstance(
                 x, Fp4QuantizedTensor
             ), "pre-quant alltoall doesn't support fp4 tensor"
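
Reviewer note: below is a minimal, standalone sketch of the gating this diff introduces, not TensorRT-LLM code; the free functions and the example numbers are invented for illustration. It mirrors the logic of the new `calculate_num_chunks` / `can_use_alltoall` methods: alltoall is skipped whenever the global batch needs chunking, and the DeepEP low-latency path additionally rejects batches larger than its buffer limit, which `TRTLLM_DEEP_EP_TOKEN_LIMIT` can override.

import os
from typing import List


def calculate_num_chunks(all_rank_num_tokens: List[int],
                         moe_max_num_tokens: int) -> int:
    # Ceiling division: how many chunks the global batch must be split into.
    num_rows = sum(all_rank_num_tokens)
    return (num_rows + moe_max_num_tokens - 1) // moe_max_num_tokens


def can_use_alltoall(num_tokens: int, all_rank_num_tokens: List[int],
                     moe_max_num_tokens: int, is_deep_ep_low_latency: bool,
                     deep_ep_max_num_tokens: int,
                     alltoall_enabled: bool) -> bool:
    # Chunked execution falls back to the allgather / reduce-scatter path.
    if calculate_num_chunks(all_rank_num_tokens, moe_max_num_tokens) > 1:
        return False
    # The DeepEP low-latency buffer is sized for at most deep_ep_max_num_tokens.
    if is_deep_ep_low_latency and num_tokens > deep_ep_max_num_tokens:
        return False
    return alltoall_enabled


if __name__ == "__main__":
    # The buffer limit defaults to min(max_num_tokens, moe_max_num_tokens)
    # and can be overridden via TRTLLM_DEEP_EP_TOKEN_LIMIT, as in the diff.
    limit = int(os.environ.get("TRTLLM_DEEP_EP_TOKEN_LIMIT",
                               str(min(8192, 4096))))
    print(can_use_alltoall(256, [256, 256], 4096, True, limit, True))    # True
    print(can_use_alltoall(3000, [3000, 3000], 4096, True, limit, True)) # False: chunked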