
Commit 74c3b2a

feat: add block FP8 on GB200 in wide EP
Signed-off-by: xxi <[email protected]>

modified: tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
new file: tensorrt_llm/_torch/modules/fused_moe/moe_backend.py
1 parent 4017f7c

File tree

2 files changed: +910 -2 lines changed

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 43 additions & 2 deletions
@@ -5,7 +5,7 @@
 import torch
 
 from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
-from tensorrt_llm._utils import logger
+from tensorrt_llm._utils import get_sm_version, logger
 from tensorrt_llm.functional import AllReduceStrategy
 from tensorrt_llm.mapping import Mapping
 
@@ -15,8 +15,10 @@
 from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .deep_ep_utils import buffer_pool, deep_ep_installed
 from .interface import MoE
+from .moe_backend import MoEBackendSelection
 from .moe_load_balancer import get_moe_load_balancer
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
+                           DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
                            FP8QDQFusedMoEMethod, MoEWeightLoadingMode,
                            NVFP4CutlassFusedMoEMethod,
                            UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod)
@@ -232,6 +234,9 @@ def __init__(
         self.enable_dummy_allreduce = os.environ.get(
             "TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"
 
+        # Select MoE backend based on configuration
+        self.moe_backend = None  # Will be initialized after weights are created
+
     def _check_configs(self):
         assert self._weights_created
 
@@ -318,7 +323,10 @@ def _get_quant_method(self):
         if self.quant_config.layer_quant_mode.has_fp8_qdq():
             return FP8QDQFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
-            return DeepSeekFP8BlockScalesFusedMoEMethod()
+            if get_sm_version() == 100:
+                return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
+            else:
+                return DeepSeekFP8BlockScalesFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_nvfp4():
             return NVFP4CutlassFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
@@ -341,6 +349,9 @@ def create_weights(self):
         self._weights_created = True
         self._check_configs()
 
+        # Initialize MoE backend after weights are created
+        self.moe_backend = MoEBackendSelection.select_backend(self)
+
     def dummy_allreduce(self):
         """
         Debug function for eliminating imbalance during performance analysis.
@@ -638,6 +649,7 @@ def forward_chunk(
                 f"Not available alltoall method type: {self.alltoall_method_type!r}"
             )
 
+        # Original fused_moe call (preserved as reference)
         final_hidden_states = torch.ops.trtllm.fused_moe(
             x,
             token_selected_slots,
@@ -665,6 +677,35 @@
             tuner_top_k=tuner_top_k,
         )
 
+        # Use the selected backend to compute MoE with the same parameters as fused_moe
+        # final_hidden_states = self.moe_backend.run_moe(
+        #     x,
+        #     token_selected_slots,
+        #     token_final_scales,
+        #     w3_w1_weight.view(weight_dtype),
+        #     None,  # w3_w1_bias
+        #     w2_weight.view(weight_dtype),
+        #     None,  # w2_bias
+        #     output_dtype,
+        #     quant_scales=quant_scales,
+        #     input_sf=x_sf,
+        #     swizzled_input_sf=False,
+        #     tp_size=self.tp_size,
+        #     tp_rank=self.tp_rank,
+        #     ep_size=ep_size,
+        #     ep_rank=ep_rank,
+        #     cluster_size=cluster_size,
+        #     cluster_rank=cluster_rank,
+        #     enable_alltoall=use_all_to_all,
+        #     use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
+        #     use_w4_group_scaling=use_w4_group_scaling,
+        #     min_latency_mode=False,
+        #     tune_max_num_tokens=self.tune_max_num_tokens,
+        #     tuner_num_tokens=tuner_num_tokens,
+        #     tuner_top_k=tuner_top_k,
+        #     module=self,  # Additional parameter for backend to access module properties
+        #     )
+
         if self.layer_load_balancer and is_last_call:
             self.layer_load_balancer.start_set_cpu_stage()
 
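The new moe_backend.py is listed in the file tree, but its contents are not part of this diff. As a minimal, hypothetical sketch only, consistent with the calls visible above (MoEBackendSelection.select_backend(self), the get_sm_version() == 100 gate, and the commented-out run_moe(...) signature), a selector could look like the following; the backend class names and their bodies are placeholders, not the commit's actual implementation:

# Hypothetical sketch only; the real moe_backend.py added by this commit is not shown here.
from abc import ABC, abstractmethod

import torch

from tensorrt_llm._utils import get_sm_version


class MoEBackend(ABC):
    """Interface a wide-EP MoE backend would need to expose."""

    @abstractmethod
    def run_moe(self, x, token_selected_slots, token_final_scales, *args,
                **kwargs) -> torch.Tensor:
        ...


class CutlassMoEBackend(MoEBackend):
    """Placeholder default path that forwards to torch.ops.trtllm.fused_moe."""

    def run_moe(self, x, token_selected_slots, token_final_scales, *args,
                **kwargs):
        kwargs.pop("module", None)  # backend-only extra argument, not a fused_moe kwarg
        return torch.ops.trtllm.fused_moe(x, token_selected_slots,
                                          token_final_scales, *args, **kwargs)


class DeepGemmMoEBackend(MoEBackend):
    """Placeholder SM100 (GB200) block-FP8 path using DeepGEMM-style grouped GEMMs."""

    def run_moe(self, x, token_selected_slots, token_final_scales, *args,
                **kwargs):
        raise NotImplementedError("block-FP8 grouped-GEMM path goes here")


class MoEBackendSelection:
    """Mirrors the gating used by _get_quant_method in the diff above."""

    @staticmethod
    def select_backend(module) -> MoEBackend:
        # Prefer the DeepGEMM path on SM100 when FP8 block scales are in use;
        # otherwise keep the existing fused_moe path.
        quant_config = getattr(module, "quant_config", None)
        has_block_fp8 = (quant_config is not None and
                         quant_config.layer_quant_mode.has_fp8_block_scales())
        if has_block_fp8 and get_sm_version() == 100:
            return DeepGemmMoEBackend()
        return CutlassMoEBackend()

With a selector of this shape, create_weights only needs the single call shown in the diff, self.moe_backend = MoEBackendSelection.select_backend(self), and forward_chunk can later switch from torch.ops.trtllm.fused_moe to self.moe_backend.run_moe without changing its argument list.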
