2 changes: 1 addition & 1 deletion examples/llm-api/quickstart_advanced.py
@@ -50,7 +50,7 @@ def add_llm_args(parser):
parser.add_argument('--moe_backend',
type=str,
default='CUTLASS',
choices=['CUTLASS', 'TRTLLM', 'VANILLA'])
choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
parser.add_argument('--enable_attention_dp',
default=False,
action='store_true')
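As a usage illustration, here is a minimal, self-contained sketch (not part of the diff) of how the new `WIDEEP` choice is consumed through the quickstart argument parser; the command-line values passed below are illustrative only.

```python
# Minimal sketch: the quickstart parser now accepts 'WIDEEP' as a MoE backend.
# The values handed to parse_args() are illustrative, not part of the PR.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe_backend',
                    type=str,
                    default='CUTLASS',
                    choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
parser.add_argument('--enable_attention_dp',
                    default=False,
                    action='store_true')

args = parser.parse_args(['--moe_backend', 'WIDEEP', '--enable_attention_dp'])
print(args.moe_backend)  # -> WIDEEP
```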
7 changes: 2 additions & 5 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -62,8 +62,7 @@
from ..modules.rms_norm import RMSNorm
from ..peft.lora.layer import LoraLayer
from ..speculative import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from ..utils import (AuxStreamType, EventType, Fp4QuantizedTensor,
disable_fp4_allgather)
from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
EagerFusionConfig, filter_weights,
register_auto_model)
@@ -514,9 +513,7 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
if self.use_dp and self.mapping.tp_size > 1:
# FP4 all_gather moves this bf16 allgather to after top-k and FP4 quantization
# to reduce allreduce bandwidth
if (disable_fp4_allgather()
and not self.experts.enable_alltoall) or isinstance(
self.experts, TRTLLMGenFusedMoE):
if isinstance(self.experts, TRTLLMGenFusedMoE):
hidden_states = allgather(hidden_states,
self.mapping,
dim=0,
84 changes: 60 additions & 24 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -192,8 +192,12 @@ def __init__(
model_config.mapping)
self.deep_ep_buffer.reserve(hidden_size, dtype)
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
self.deep_ep_max_num_tokens = min(model_config.max_num_tokens,
self.moe_max_num_tokens)
self.deep_ep_max_num_tokens = int(
os.environ.get(
"TRTLLM_DEEP_EP_TOKEN_LIMIT",
str(
min(model_config.max_num_tokens,
self.moe_max_num_tokens))))
self.deep_ep_buffer = buffer_pool.get_low_latency_buffer(
model_config.mapping)
self.deep_ep_buffer.reserve(self.deep_ep_max_num_tokens,
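For context, the hunk above lets `TRTLLM_DEEP_EP_TOKEN_LIMIT` override the default token cap (the minimum of `max_num_tokens` and `moe_max_num_tokens`). A minimal sketch of setting the override follows; the value is illustrative, and it must be set before the MoE module is constructed so the `os.environ.get` call above sees it.

```python
# Minimal sketch (illustrative value): override the DeepEP low-latency token
# limit read in __init__ above. Set the variable before building the model.
import os

os.environ["TRTLLM_DEEP_EP_TOKEN_LIMIT"] = "2048"
```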
@@ -274,6 +278,25 @@ def enable_alltoall(self):
"""
return self.alltoall_method_type != AlltoallMethodType.NotEnabled

def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
num_rows = sum(all_rank_num_tokens)
return (num_rows + self.moe_max_num_tokens -
1) // self.moe_max_num_tokens

def can_use_alltoall(self, input, all_rank_num_tokens):
# Disable alltoall when chunking is used
if self.calculate_num_chunks(all_rank_num_tokens) > 1:
return False

num_tokens = input.shape[0]

# For DeepEPLowLatency, check if tokens exceed the threshold
if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
and num_tokens > self.deep_ep_max_num_tokens):
return False

return self.enable_alltoall
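To make the new gating concrete, here is a standalone sketch of the same decision (parameter names and numbers are illustrative stand-ins for the module's attributes, not its API): alltoall is used only when the whole batch fits in a single MoE chunk and, on the DeepEP low-latency path, when this rank's token count stays within the reserved buffer.

```python
# Standalone sketch of the alltoall gating above; all parameters are
# illustrative stand-ins for the module's attributes.
from typing import List


def can_use_alltoall(all_rank_num_tokens: List[int], num_tokens: int,
                     moe_max_num_tokens: int, deep_ep_max_num_tokens: int,
                     is_deep_ep_low_latency: bool,
                     alltoall_enabled: bool) -> bool:
    num_rows = sum(all_rank_num_tokens)
    num_chunks = (num_rows + moe_max_num_tokens - 1) // moe_max_num_tokens
    if num_chunks > 1:  # chunked execution falls back to allgather/reducescatter
        return False
    if is_deep_ep_low_latency and num_tokens > deep_ep_max_num_tokens:
        return False
    return alltoall_enabled


# Example: 4 ranks with 1024 tokens each fit in one 8192-token chunk.
print(can_use_alltoall([1024] * 4, 1024, 8192, 2048, True, True))  # True
```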

def _get_quant_method(self):
if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True):
@@ -316,11 +339,12 @@ def dummy_allreduce(self):
def reducescatter_or_allreduce(
self,
inputs,
use_all_to_all: bool,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
):
outputs = inputs
if not self.enable_alltoall:
if not use_all_to_all:
if self.enable_dummy_allreduce:
self.dummy_allreduce()
outputs = reducescatter(
@@ -334,6 +358,7 @@ def forward_chunk(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
use_all_to_all: bool,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
all_rank_max_num_tokens: Optional[int] = None,
@@ -382,7 +407,7 @@ def forward_chunk(
) and is_first_call:
self.layer_load_balancer.maybe_cudagraph_done_wait()

use_allgather = not self.enable_alltoall
use_allgather = not use_all_to_all

loadbalancer_local_statistic_info = None
gathered_loadbalancer_local_statistic_info = None
@@ -391,7 +416,7 @@ def forward_chunk(
token_selected_slots = token_selected_experts
else:
if not self.layer_load_balancer.is_static_routing(
) and self.enable_alltoall:
) and use_all_to_all:
self.layer_load_balancer.local_statistic(
token_selected_experts,
is_first_stage=is_first_call,
@@ -400,7 +425,7 @@ def forward_chunk(
token_selected_experts, self.use_dp)
if not self.layer_load_balancer.is_static_routing():
# split into two parts to get possible overlap with load balancer routing
if self.enable_alltoall:
if use_all_to_all:
if is_last_call:
loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
)
@@ -412,7 +437,9 @@ def forward_chunk(
ExpertStatistic.set_layer(self.layer_idx)
ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)

if self.enable_alltoall:
# If alltoall is disabled, we also need to disable use_postquant_alltoall
use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all
if use_all_to_all:
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
if self.enable_dummy_allreduce:
self.dummy_allreduce()
@@ -423,13 +450,14 @@ def forward_chunk(
x,
token_selected_slots,
token_final_scales,
use_postquant_alltoall,
loadbalancer_local_statistic_info)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
if not self.use_postquant_alltoall:
if not use_postquant_alltoall:
x, recv_topk_idx, token_final_scales, num_recv_tokens_per_expert_list, deep_ep_handle = \
self.deep_ep_buffer.dispatch(x, token_selected_slots.to(torch.int64), token_final_scales, self.num_slots)
elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
if not self.use_postquant_alltoall:
if not use_postquant_alltoall:
deep_ep_topk_idx = token_selected_slots.to(torch.int64)
deep_ep_topk_weights = token_final_scales
x, recv_expert_count, deep_ep_handle = \
@@ -469,7 +497,7 @@ def forward_chunk(
x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, self.fc31_input_dequant)
elif self.has_nvfp4:
if use_allgather or self.use_postquant_alltoall:
if use_allgather or use_postquant_alltoall:
if isinstance(x, Fp4QuantizedTensor):
if use_allgather:
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before allgather"
@@ -525,7 +553,7 @@ def forward_chunk(

if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing(
):
if self.enable_alltoall:
if use_all_to_all:
if is_last_call:
gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
(self.mapping.moe_ep_size, self.num_experts))
@@ -545,7 +573,7 @@ def forward_chunk(
cluster_rank = self.cluster_rank
quant_scales = self.quant_scales

if self.use_postquant_alltoall:
if use_postquant_alltoall:
if x_sf is not None and self.has_nvfp4:
assert not x_is_sf_swizzled, "Fp4 scaling factor should not be swizzled before Alltoall"
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -636,7 +664,7 @@ def forward_chunk(
f"Not available alltoall method type: {self.alltoall_method_type!r}"
)

if self.enable_alltoall:
if use_all_to_all:
# Adapter between `torch.ops.trtllm.fused_moe` and DeepEP
# TODO: remove the adapter by changing APIs
if self.alltoall_method_type == AlltoallMethodType.DeepEP:
@@ -676,7 +704,7 @@ def forward_chunk(
ep_rank=ep_rank,
cluster_size=cluster_size,
cluster_rank=cluster_rank,
enable_alltoall=self.enable_alltoall,
enable_alltoall=use_all_to_all,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
use_w4a8_group_scaling=use_w4a8_group_scaling,
min_latency_mode=False,
@@ -691,7 +719,7 @@ def forward_chunk(
# Otherwise, the output should be unpacked as a single tensor.
final_hidden_states = final_hidden_states[0]

if self.enable_alltoall:
if use_all_to_all:
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
if self.enable_dummy_allreduce:
self.dummy_allreduce()
@@ -747,11 +775,10 @@ def forward(
) -> torch.Tensor:
assert all_rank_num_tokens is not None
assert use_dp_padding is not None
num_rows = sum(all_rank_num_tokens)

# in case num_rows is larger than max_chunk_size, we need to split the input into multiple chunks
num_chunks = (num_rows + self.moe_max_num_tokens -
1) // self.moe_max_num_tokens
num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
use_all_to_all = self.can_use_alltoall(x, all_rank_num_tokens)

if use_dp_padding:
all_rank_num_tokens_padded = [all_rank_max_num_tokens
@@ -764,13 +791,15 @@
outputs = self.forward_chunk(
x,
router_logits,
use_all_to_all,
output_dtype,
all_rank_num_tokens=all_rank_num_tokens_padded,
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
outputs = self.reducescatter_or_allreduce(
outputs,
use_all_to_all,
all_rank_num_tokens=all_rank_num_tokens_padded,
use_dp_padding=use_dp_padding)
else:
@@ -792,7 +821,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
all_rank_max_num_tokens_list = split_chunk(all_rank_max_num_tokens,
num_chunks)
chunk_size_list = all_rank_chunk_size_list[self.rank]
if self.enable_alltoall:
if use_all_to_all:
all_rank_num_tokens_list = [[
1 if val == 0 else val for val in val_list
] for val_list in all_rank_num_tokens_list]
@@ -804,7 +833,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
x_list = x.split(chunk_size_list)
router_logits_list = router_logits.split(chunk_size_list)

if not self.enable_alltoall:
if not use_all_to_all:
self.event_dict[EventType.Main].record()
with torch.cuda.stream(self.aux_stream):
self.event_dict[EventType.Main].wait()
@@ -815,12 +844,13 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
zip(x_list, router_logits_list)):
is_first_call = idx_chunk == 0 and self.repeat_idx == 0
is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1
if not self.enable_alltoall:
if not use_all_to_all:
if idx_chunk % 2 == 0:
with torch.cuda.stream(self.aux_stream):
outputs = self.forward_chunk(
x,
router_logits,
use_all_to_all,
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk],
all_rank_max_num_tokens=
@@ -830,13 +860,15 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
if idx_chunk > 0:
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1],
use_all_to_all,
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk - 1],
use_dp_padding=use_dp_padding)
else:
outputs = self.forward_chunk(
x,
router_logits,
use_all_to_all,
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk],
all_rank_max_num_tokens=all_rank_max_num_tokens_list[
@@ -846,29 +878,33 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1],
use_all_to_all,
all_rank_num_tokens=all_rank_num_tokens_list[
idx_chunk - 1],
use_dp_padding=use_dp_padding)
else:
outputs = self.forward_chunk(
x,
router_logits,
use_all_to_all,
all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
all_rank_max_num_tokens=all_rank_max_num_tokens_list[
idx_chunk],
repeating_info=(is_first_call, is_last_call))

outputs_list.append(outputs)
if not self.enable_alltoall:
if not use_all_to_all:
if num_chunks % 2 == 0:
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1],
use_all_to_all,
all_rank_num_tokens=all_rank_num_tokens_list[-1],
use_dp_padding=use_dp_padding)
else:
with torch.cuda.stream(self.aux_stream):
outputs_list[-1] = self.reducescatter_or_allreduce(
outputs_list[-1],
use_all_to_all,
all_rank_num_tokens=all_rank_num_tokens_list[-1],
use_dp_padding=use_dp_padding)
with torch.cuda.stream(self.aux_stream):
@@ -883,7 +919,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
def alltoall_prepare_maybe_dispatch(
self, all_rank_max_num_tokens: int, x: torch.Tensor,
token_selected_slots: torch.Tensor,
token_final_scales: torch.Tensor,
token_final_scales: torch.Tensor, use_postquant_alltoall: bool,
local_statistic_tensor: Optional[torch.Tensor]):
top_k = self.routing_method.experts_per_token

@@ -929,7 +965,7 @@ def alltoall_prepare_maybe_dispatch(
gathered_token_final_scales, all_rank_max_num_tokens,
self.num_slots, top_k, self.ep_rank, self.ep_size)

if not self.use_postquant_alltoall:
if not use_postquant_alltoall:
assert not isinstance(
x, Fp4QuantizedTensor
), "pre-quant alltoall doesn't support fp4 tensor"