Commit 1646384

bobboli authored and evezhier committed
chore: Cleanup disable_fp4_allgather. (#6006)
Signed-off-by: Bo Li <[email protected]>
1 parent 8907a65 commit 1646384

4 files changed, +3 −20 lines changed

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 0 additions & 2 deletions
@@ -511,8 +511,6 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
         # max-throughput
         use_dp_padding = False
         if self.use_dp and self.mapping.tp_size > 1:
-            # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
-            # to reduce allreduce BW
             if isinstance(self.experts, TRTLLMGenFusedMoE):
                 hidden_states = allgather(hidden_states,
                                           self.mapping,

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 1 addition & 6 deletions
@@ -20,7 +20,6 @@
 from ..modules.linear import TensorParallelMode
 from ..modules.rms_norm import RMSNorm
 from ..speculative import SpecMetadata
-from ..utils import disable_fp4_allgather
 from .modeling_qwen3 import Qwen3Attention
 from .modeling_speculative import SpecDecOneEngineForCausalLM
 from .modeling_utils import (DecoderModel, EagerFusionConfig,
@@ -133,11 +132,7 @@ def forward(
         assert not self.enable_attention_dp

         if self.enable_attention_dp and self.mapping.tp_size > 1:
-            # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
-            # to reduce allreduce BW
-            if (disable_fp4_allgather()
-                    and not self.experts.enable_alltoall) or isinstance(
-                        self.experts, TRTLLMGenFusedMoE):
+            if isinstance(self.experts, TRTLLMGenFusedMoE):
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           dim=0,

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 2 additions & 4 deletions
@@ -4,8 +4,7 @@

 from ...distributed import allgather, reducescatter
 from ...model_config import ModelConfig
-from ...utils import (EventType, Fp4QuantizedTensor, ceil_div,
-                      disable_fp4_allgather, swizzle_sf)
+from ...utils import EventType, Fp4QuantizedTensor, ceil_div, swizzle_sf
 from .interface import MoE
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
                            FP8QDQFusedMoEMethod, MoEWeightLoadingMode,
@@ -220,8 +219,7 @@ def forward_chunk(
         # TODO: remove this once we have correct fusedmoe kernel ready
         token_final_scales = None

-        use_allgather = self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather(
-        )
+        use_allgather = self.use_dp and self.parallel_size > 1

         # quantize inputs
         use_deepseek_fp8_block_scale = False
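
The gist of the change in forward_chunk, shown as a minimal sketch: use_allgather no longer consults the TLLM_DISABLE_FP4_ALLGATHER environment variable. Only use_dp, parallel_size, and the env-var name come from the diff; the standalone helper functions below are illustrative, not the actual class methods.

import os

def use_allgather_before(use_dp: bool, parallel_size: int) -> bool:
    # Old behavior: a process-wide env var, read once at import time in
    # tensorrt_llm/_torch/utils.py, could veto the allgather path.
    disabled = os.getenv("TLLM_DISABLE_FP4_ALLGATHER", "0") == "1"
    return use_dp and parallel_size > 1 and not disabled

def use_allgather_after(use_dp: bool, parallel_size: int) -> bool:
    # New behavior: the decision depends only on data parallelism and the
    # parallel size; there is no environment-variable override.
    return use_dp and parallel_size > 1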

tensorrt_llm/_torch/utils.py

Lines changed: 0 additions & 8 deletions
@@ -1,5 +1,4 @@
 import contextlib
-import os
 import threading
 from dataclasses import dataclass
 from enum import Enum
@@ -100,13 +99,6 @@ def shape(self):
         return self.fp4_tensor.shape


-_disable_fp4_allgather = os.getenv("TLLM_DISABLE_FP4_ALLGATHER", "0") == "1"
-
-
-def disable_fp4_allgather():
-    return _disable_fp4_allgather
-
-
 def compute_swizzled_sf_shape(row: int, col: int):
     padded_row = pad_up(row, 128)
     padded_col = pad_up(col, 4)
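
Since disable_fp4_allgather is removed from tensorrt_llm._torch.utils, any out-of-tree code that still imports it will now raise ImportError. A hypothetical compatibility shim, not part of this commit and with an assumed always-False fallback, could look like:

try:
    from tensorrt_llm._torch.utils import disable_fp4_allgather
except ImportError:
    def disable_fp4_allgather() -> bool:
        # The env-var gate was removed upstream; treat it as permanently off.
        return False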
