Commit 77d53e9

feat: wide_ep support block-wise FP8 on Blackwell
Signed-off-by: xxi <[email protected]>

modified:  tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
new file:  tensorrt_llm/_torch/modules/fused_moe/moe_backend.py
modified:  tests/unittest/_torch/modules/test_fused_moe.py
1 parent e257cb3 commit 77d53e9
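
Of the three files touched by this commit, only the diff for fused_moe_wide_ep.py appears below; the new moe_backend.py is not shown. Judging purely from how it is used in the hunks that follow (MoEBackendSelection.select_backend(self) and moe_backend_impl.run_moe(self, ...)), its selection logic might look roughly like the sketch below; everything except the names that actually appear in the diff is an assumption.

# Illustrative sketch only -- moe_backend.py is added by this commit but not shown here.
# Only MoEBackendSelection.select_backend / run_moe are taken from the diff; the class
# names and the exact selection rule below are hypothetical.

class _CutlassLikeBackend:
    def run_moe(self, module, x, token_selected_slots, token_final_scales, **kwargs):
        # Would forward to the pre-existing torch.ops.trtllm.fused_moe path.
        raise NotImplementedError


class _DeepGemmLikeBackend:
    def run_moe(self, module, x, token_selected_slots, token_final_scales, **kwargs):
        # Would forward to a DeepGEMM block-wise FP8 grouped-GEMM path (Blackwell).
        raise NotImplementedError


def select_backend_sketch(module, sm_version: int):
    """Pick a backend from hardware and quantization config, as the property docstring describes."""
    if sm_version == 100 and module.has_deepseek_fp8_block_scales:
        return _DeepGemmLikeBackend()
    return _CutlassLikeBackend()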

File tree

3 files changed (+977, -21 lines)


tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 31 additions & 21 deletions
@@ -5,8 +5,9 @@
 import torch
 
 from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
-from tensorrt_llm._utils import logger
+from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.functional import AllReduceStrategy
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 
 from ...distributed import AllReduce, allgather, reducescatter
@@ -15,8 +16,10 @@
 from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
+from .moe_backend import MoEBackend, MoEBackendSelection
 from .moe_load_balancer import get_moe_load_balancer
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
+                           DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
                            FP8QDQFusedMoEMethod, MoEWeightLoadingMode,
                            NVFP4CutlassFusedMoEMethod,
                            UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod)
@@ -90,6 +93,9 @@ def __init__(
         self.apply_router_weight_on_input = apply_router_weight_on_input
         self.layer_idx = layer_idx
 
+        # Store original hidden size before any potential padding
+        self.unpadded_hidden_size = self.hidden_size
+
         moe_load_balancer = get_moe_load_balancer()
         self.layer_load_balancer = None
         self.repeat_idx = 0
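
The new unpadded_hidden_size field only records the pre-padding width; the padding itself happens elsewhere and is not part of this excerpt. As a hedged illustration of the usual pattern (pad the hidden dimension up to an alignment boundary for the kernel, then slice the result back), with the 128-element alignment being an assumption rather than something taken from the diff:

import torch

def pad_and_slice_sketch(x: torch.Tensor, alignment: int = 128) -> torch.Tensor:
    """Hypothetical illustration of why the original hidden size must be remembered."""
    unpadded_hidden_size = x.shape[-1]               # what the module stores up front
    pad = (-unpadded_hidden_size) % alignment
    x_padded = torch.nn.functional.pad(x, (0, pad))  # pad the last dim to the alignment
    y_padded = x_padded * 2.0                        # stand-in for the padded MoE kernel
    return y_padded[..., :unpadded_hidden_size]      # slice back to the original width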
@@ -227,6 +233,9 @@ def __init__(
         self.enable_dummy_allreduce = os.environ.get(
             "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
 
+        # MoE backend will be lazily initialized when first accessed (see moe_backend property)
+        self._moe_backend_impl = None
+
     def _check_configs(self):
         assert self._weights_created
 
@@ -316,7 +325,10 @@ def _get_quant_method(self):
         if self.quant_config.layer_quant_mode.has_fp8_qdq():
             return FP8QDQFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
-            return DeepSeekFP8BlockScalesFusedMoEMethod()
+            if get_sm_version() == 100:
+                return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
+            else:
+                return DeepSeekFP8BlockScalesFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_nvfp4():
             return NVFP4CutlassFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
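
For context, get_sm_version() == 100 corresponds to Blackwell-class GPUs (compute capability 10.0); the helper itself lives in tensorrt_llm._utils and is not shown here. A minimal stand-in using plain PyTorch, assuming the usual major * 10 + minor encoding:

import torch

def sm_version_sketch() -> int:
    """Approximate get_sm_version(): compute capability major * 10 + minor (10.0 -> 100)."""
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor

if torch.cuda.is_available():
    # Mirrors the hunk above: DeepGEMM block-scale method on SM 100, CUTLASS path otherwise.
    use_deepgemm_method = sm_version_sketch() == 100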
@@ -339,6 +351,19 @@ def create_weights(self):
         self._weights_created = True
         self._check_configs()
 
+    @property
+    def moe_backend_impl(self) -> MoEBackend:
+        """
+        Lazily initialize and return the MoE backend.
+
+        The backend is selected based on hardware capabilities and quantization
+        configuration, which are only available after weights are created.
+        """
+        if self._moe_backend_impl is None:
+            assert self._weights_created, "Weights must be created before accessing moe_backend"
+            self._moe_backend_impl = MoEBackendSelection.select_backend(self)
+        return self._moe_backend_impl
+
     def dummy_allreduce(self):
         """
         Debug function for eliminating imbalance during performance analysis.
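
A self-contained demo of the access ordering this property enforces (not TensorRT-LLM code; a plain object stands in for MoEBackendSelection.select_backend):

class LazyBackendDemo:
    """Toy reproduction of the moe_backend_impl pattern above."""

    def __init__(self):
        self._weights_created = False
        self._moe_backend_impl = None

    def create_weights(self):
        self._weights_created = True

    @property
    def moe_backend_impl(self):
        if self._moe_backend_impl is None:
            assert self._weights_created, "Weights must be created before accessing moe_backend"
            self._moe_backend_impl = object()  # stand-in for MoEBackendSelection.select_backend(self)
        return self._moe_backend_impl


demo = LazyBackendDemo()
demo.create_weights()                                  # must run before first access
assert demo.moe_backend_impl is demo.moe_backend_impl  # selected once, then cached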
@@ -389,8 +414,6 @@ def forward_chunk(
         if self.layer_load_balancer and is_first_call:
             self.layer_load_balancer.start_wait_gpu_stage()
 
-        use_deepseek_fp8_block_scale = False
-        use_w4_group_scaling = False
         weight_dtype = self.w3_w1_weight.dtype
 
         token_selected_experts, token_final_scales = self.routing_method.apply(
@@ -544,9 +567,8 @@
             x_sf = x_sf.view((x_row, -1))
 
         elif self.has_deepseek_fp8_block_scales:
-            use_deepseek_fp8_block_scale = True
+            pass
         elif self.has_w4afp8:
-            use_w4_group_scaling = True
             weight_dtype = torch.quint4x2
         else:
             raise ValueError(
@@ -569,12 +591,8 @@
             sizes=None if use_dp_padding else all_rank_num_tokens)
         x_row = x.shape[0]
 
-        ep_size = self.ep_size
-        ep_rank = self.ep_rank
         w3_w1_weight = self.w3_w1_weight
         w2_weight = self.w2_weight
-        cluster_size = self.cluster_size
-        cluster_rank = self.cluster_rank
         quant_scales = self.quant_scales
 
         if self.alltoall_method_type == AlltoallMethodType.MNNVL:
@@ -640,7 +658,8 @@
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
             )
 
-        final_hidden_states = torch.ops.trtllm.fused_moe(
+        final_hidden_states = self.moe_backend_impl.run_moe(
+            self,
             x,
             token_selected_slots,
             token_final_scales,
@@ -652,17 +671,8 @@
             quant_scales=quant_scales,
             input_sf=x_sf,
             swizzled_input_sf=False,
-            tp_size=self.tp_size,
-            tp_rank=self.tp_rank,
-            ep_size=ep_size,
-            ep_rank=ep_rank,
-            cluster_size=cluster_size,
-            cluster_rank=cluster_rank,
-            enable_alltoall=use_all_to_all,
-            use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
-            use_w4_group_scaling=use_w4_group_scaling,
             min_latency_mode=False,
-            tune_max_num_tokens=self.tune_max_num_tokens,
+            use_fused_finalize=True,
             tuner_num_tokens=tuner_num_tokens,
             tuner_top_k=tuner_top_k,
         )
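
The tp/ep/cluster sizes and ranks, the quantization flags, and tune_max_num_tokens are no longer passed at this call site; presumably the selected backend recovers them from the module handed in as the first argument. A hedged sketch of what such a wrapper around the pre-existing op could look like (moe_backend.py is not shown, so the class name and the exact argument translation are assumptions):

import torch

class CutlassBackendSketch:
    """Hypothetical wrapper (moe_backend.py not shown): forwards to the pre-existing
    fused_moe op, recovering from `module` the parallelism arguments that this hunk
    removed from the call site. Flag translation (enable_alltoall, use_fused_finalize,
    tune_max_num_tokens, ...) is deliberately omitted here."""

    def run_moe(self, module, x, token_selected_slots, token_final_scales, *args, **kwargs):
        return torch.ops.trtllm.fused_moe(
            x,
            token_selected_slots,
            token_final_scales,
            *args,
            tp_size=module.tp_size,
            tp_rank=module.tp_rank,
            ep_size=module.ep_size,
            ep_rank=module.ep_rank,
            cluster_size=module.cluster_size,
            cluster_rank=module.cluster_rank,
            use_deepseek_fp8_block_scale=module.has_deepseek_fp8_block_scales,
            use_w4_group_scaling=module.has_w4afp8,
            **kwargs,
        )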
