|
 from tensorrt_llm._ipc_utils import can_access_peer
 from tensorrt_llm._utils import get_sm_version
-from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.functional import AllReduceStrategy, PositionEmbeddingType
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
|
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
 from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
                            MoEAllReduce, MoEAllReduceParams, allgather)
+from ..distributed.ops import MNNVLAllReduce
 from ..model_config import ModelConfig
 from ..modules.attention import MLA
 from ..modules.decoder_layer import DecoderLayer
@@ -738,10 +739,24 @@ def _compute_mlp_tp_size(self, intermediate_size: int, |
             intermediate_size // block_size,
             self.mapping.tp_size,
         )
-        mlp_tp_size = math.gcd(
-            tp,
-            self.mapping.gpus_per_node,
-        ) if tp > self.mapping.gpus_per_node else tp  # Avoid costly inter-node TP
+        # mlp_tp_size = math.gcd(
+        #     tp,
+        #     self.mapping.gpus_per_node,
+        # ) if tp > self.mapping.gpus_per_node else tp  # Avoid costly inter-node TP
+        if tp > self.mapping.gpus_per_node and (
+                self.model_config.allreduce_strategy not in (
+                    AllReduceStrategy.AUTO,
+                    AllReduceStrategy.MNNVL,
+                ) or not MNNVLAllReduce.is_mnnvl(
+                    self.mapping,
+                    self.model_config.pretrained_config.torch_dtype)):
+            mlp_tp_size = math.gcd(
+                tp,
+                self.mapping.gpus_per_node,
+            )  # Avoid costly inter-node TP when MNNVL is not supported and tp > gpus_per_node
+        else:
+            mlp_tp_size = tp
+
         return mlp_tp_size

     def forward(
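
For readers following the change, the sketch below restates the new branch in isolation. It is a minimal, self-contained approximation rather than the TensorRT-LLM implementation: Mapping, AllReduceStrategy, and compute_mlp_tp_size here are simplified stand-ins for the real classes, the mnnvl_supported flag replaces the actual MNNVLAllReduce.is_mnnvl(mapping, dtype) check, and the sizes in the example are illustrative only.

import math
from dataclasses import dataclass
from enum import Enum, auto


class AllReduceStrategy(Enum):
    # Stand-in for the real enum; only the members used below.
    AUTO = auto()
    MNNVL = auto()
    NCCL = auto()


@dataclass
class Mapping:
    # Hypothetical, simplified stand-in for tensorrt_llm.mapping.Mapping.
    tp_size: int
    gpus_per_node: int


def compute_mlp_tp_size(intermediate_size: int, block_size: int,
                        mapping: Mapping, strategy: AllReduceStrategy,
                        mnnvl_supported: bool) -> int:
    # TP cannot exceed the number of quantization blocks in the weight.
    tp = math.gcd(intermediate_size // block_size, mapping.tp_size)

    # The MNNVL all-reduce path is only considered under the AUTO or MNNVL
    # strategy, and only if the platform actually supports it.
    mnnvl_usable = (strategy in (AllReduceStrategy.AUTO,
                                 AllReduceStrategy.MNNVL) and mnnvl_supported)
    if tp > mapping.gpus_per_node and not mnnvl_usable:
        # Inter-node TP without MNNVL is costly: shrink TP to one node.
        return math.gcd(tp, mapping.gpus_per_node)
    return tp


# Example: 16-way TP on 8-GPU nodes.
mapping = Mapping(tp_size=16, gpus_per_node=8)
print(compute_mlp_tp_size(18432, 128, mapping, AllReduceStrategy.NCCL,
                          mnnvl_supported=False))  # -> 8, stay intra-node
print(compute_mlp_tp_size(18432, 128, mapping, AllReduceStrategy.MNNVL,
                          mnnvl_supported=True))   # -> 16, MNNVL spans nodes

The net effect of the patch is that the intra-node fallback is kept only when MNNVL cannot help; when the MNNVL all-reduce path is available, the MLP keeps its full tensor-parallel width across nodes.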
|