
Commit 373e11f

fix hang
Signed-off-by: xxi <[email protected]>
1 parent 74c3b2a commit 373e11f

2 files changed: 30 additions and 7 deletions


tensorrt_llm/_torch/distributed/ops.py

Lines changed: 10 additions & 2 deletions
@@ -455,8 +455,16 @@ def __init__(self,
         self.workspace = get_allreduce_workspace(self.mapping)

         # Initialize MNNVL AllReduce if needed
-        if self.strategy == AllReduceStrategy.MNNVL:
-            if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
+        # if self.strategy == AllReduceStrategy.MNNVL:
+        #     if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
+        if self.strategy in (AllReduceStrategy.AUTO,
+                             AllReduceStrategy.MNNVL):
+            if self.mapping.tp_size != self.mapping.world_size:
+                logger.debug(
+                    f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} "
+                    f"!= world_size:{self.mapping.world_size}")
+                self.mnnvl_allreduce = None
+            elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
                 try:
                     self.mnnvl_allreduce = MNNVLAllReduce(
                         self.mapping, dtype) if dtype else None
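
In effect, the constructor now attempts MNNVL AllReduce only when the strategy allows it (AUTO or MNNVL) and the job is pure tensor parallelism (tp_size == world_size). The sketch below is a minimal, standalone restatement of that gating decision, not the library's API; the function name and the boolean mnnvl_supported parameter (standing in for MNNVLAllReduce.is_mnnvl in the real code) are illustrative.

from enum import Enum, auto


class AllReduceStrategy(Enum):
    # Illustrative subset of the real enum; only the members used below.
    AUTO = auto()
    MNNVL = auto()


def should_try_mnnvl(strategy: AllReduceStrategy, tp_size: int,
                     world_size: int, mnnvl_supported: bool) -> bool:
    """Sketch of the new gating: MNNVL is attempted only when the strategy
    allows it, the mapping is pure tensor parallelism, and the platform/dtype
    check passes."""
    if strategy not in (AllReduceStrategy.AUTO, AllReduceStrategy.MNNVL):
        return False
    if tp_size != world_size:
        # e.g. TP combined with PP: the commit now skips MNNVL here instead
        # of constructing it, which is presumably the hang being fixed.
        return False
    return mnnvl_supported


# Example: TP=4 inside an 8-rank job falls back to other all-reduce kernels.
print(should_try_mnnvl(AllReduceStrategy.AUTO, 4, 8, True))   # False
print(should_try_mnnvl(AllReduceStrategy.MNNVL, 8, 8, True))  # True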

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 20 additions & 5 deletions
@@ -41,7 +41,7 @@

 from tensorrt_llm._ipc_utils import can_access_peer
 from tensorrt_llm._utils import get_sm_version
-from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.functional import AllReduceStrategy, PositionEmbeddingType
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
@@ -52,6 +52,7 @@
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
 from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
                            MoEAllReduce, MoEAllReduceParams, allgather)
+from ..distributed.ops import MNNVLAllReduce
 from ..model_config import ModelConfig
 from ..modules.attention import MLA
 from ..modules.decoder_layer import DecoderLayer
@@ -738,10 +739,24 @@ def _compute_mlp_tp_size(self, intermediate_size: int,
                 intermediate_size // block_size,
                 self.mapping.tp_size,
             )
-            mlp_tp_size = math.gcd(
-                tp,
-                self.mapping.gpus_per_node,
-            ) if tp > self.mapping.gpus_per_node else tp  # Avoid costly inter-node TP
+            # mlp_tp_size = math.gcd(
+            #     tp,
+            #     self.mapping.gpus_per_node,
+            # ) if tp > self.mapping.gpus_per_node else tp  # Avoid costly inter-node TP
+            if tp > self.mapping.gpus_per_node and (
+                    self.model_config.allreduce_strategy not in (
+                        AllReduceStrategy.AUTO,
+                        AllReduceStrategy.MNNVL,
+                    ) or not MNNVLAllReduce.is_mnnvl(
+                        self.mapping,
+                        self.model_config.pretrained_config.torch_dtype)):
+                mlp_tp_size = math.gcd(
+                    tp,
+                    self.mapping.gpus_per_node,
+                )  # Avoid costly inter-node TP when MNNVL is not supported and tp > gpus_per_node
+            else:
+                mlp_tp_size = tp
+
         return mlp_tp_size

     def forward(
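
The net effect on MLP tensor-parallel sizing: the node-boundary cap (gcd with gpus_per_node) is now applied only when MNNVL all-reduce is not usable, so TP can keep its full width across nodes when it is. Below is a standalone sketch of that decision under illustrative parameters; compute_mlp_tp_size and mnnvl_usable are stand-ins for the real method and its strategy plus MNNVLAllReduce.is_mnnvl checks, not the repository's interface.

import math


def compute_mlp_tp_size(tp: int, gpus_per_node: int, mnnvl_usable: bool) -> int:
    """Sketch of the updated sizing rule (parameters are illustrative stand-ins
    for the mapping/model_config fields used in the real method)."""
    if tp > gpus_per_node and not mnnvl_usable:
        # Cap TP at a divisor of the node size to avoid costly inter-node
        # all-reduce when MNNVL cannot handle it.
        return math.gcd(tp, gpus_per_node)
    # Otherwise keep the full TP width, even across nodes.
    return tp


# 16-way TP on 8-GPU nodes: capped to 8 without MNNVL, kept at 16 with it.
print(compute_mlp_tp_size(16, 8, mnnvl_usable=False))  # 8
print(compute_mlp_tp_size(16, 8, mnnvl_usable=True))   # 16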
