
Commit 2101d46

[TRTLLM-6342][feat] TP Sharding read from the model config (#6972)
Signed-off-by: greg-kwasniewski1 <[email protected]>
Co-authored-by: Suyog Gupta <[email protected]>
1 parent 97d550b commit 2101d46

10 files changed: +753 −302 lines


tensorrt_llm/_torch/auto_deploy/config/default.yaml
Lines changed: 3 additions & 5 deletions

@@ -52,13 +52,11 @@ transforms:
   quantize_moe:
     stage: pattern_matcher
   # TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config.
-  detect_column_row_shard:
+  detect_sharding:
     stage: sharding
     simple_shard_only: false
-  detect_ep_shard:
-    stage: sharding
-  detect_dp_bmm_shard:
-    stage: sharding
+    use_sharding_from_factory: false
+    sharding_dims: ['tp', 'ep', 'dp']
   # TODO: (hg) need to ensure run_shape_prop after sharding.
   sharding_transform_executor:
     stage: sharding
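For reference, a minimal standalone sketch (assuming PyYAML is installed; the nesting of the new keys under detect_sharding follows the diff above and is otherwise an assumption) showing how the reshaped transform entry parses, with the former detect_ep_shard / detect_dp_bmm_shard entries folded into the sharding_dims list:

# Illustrative parse of the new detect_sharding transform config (hypothetical snippet,
# mirroring the diff above; requires PyYAML).
import yaml

snippet = """
transforms:
  detect_sharding:
    stage: sharding
    simple_shard_only: false
    use_sharding_from_factory: false
    sharding_dims: ['tp', 'ep', 'dp']
"""

cfg = yaml.safe_load(snippet)
detect_sharding = cfg["transforms"]["detect_sharding"]
# The single transform now carries the heuristic dimensions that previously had
# their own transform entries (tp, ep, dp).
assert detect_sharding["sharding_dims"] == ["tp", "ep", "dp"]
print(detect_sharding)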

tensorrt_llm/_torch/auto_deploy/llm_args.py
Lines changed: 11 additions & 0 deletions

@@ -159,6 +159,17 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         "If False, auto-detect and use column+row (all_reduce) sharding when possible.",
     )
 
+    use_sharding_from_factory: bool = Field(
+        default=False,
+        description="If True, use sharding from the model factory. If False, use sharding from the "
+        "AutoDeployConfig.",
+    )
+
+    sharding_dims: List[str] = Field(
+        default=["tp", "ep", "dp"],
+        description="The sharding methods to apply by the heuristic sharding stage.",
+    )
+
     compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
         Field(
             default="torch-compile",

tensorrt_llm/_torch/auto_deploy/models/factory.py
Lines changed: 14 additions & 0 deletions

@@ -2,6 +2,7 @@
 
 import copy
 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import Any, Callable, Dict, Optional, Type
 
 import torch
@@ -12,6 +13,13 @@
 from ..utils.logger import ad_logger
 
 
+class ShardingConfigSource(Enum):
+    """Enum for factory source."""
+
+    HUGGINGFACE = "huggingface"
+    UNKNOWN = "unknown"
+
+
 class ModelFactory(ABC):
     """An interface to return and correctly initialize a model from a desired source.
 
@@ -38,6 +46,8 @@ def __init__(
         self.max_seq_len = max_seq_len
         self._prefetched_model_path: Optional[str] = None
         self._prefetched_tokenizer_path: Optional[str] = None
+        self._sharding_config: Dict[str, Any] = {}
+        self._sharding_config["source"] = ShardingConfigSource.UNKNOWN
 
     @property
     def model(self) -> Optional[str]:
@@ -96,6 +106,10 @@ def get_quant_config(self) -> Dict:
         """Returns the quantization config for this model or None if not quantized."""
         return {}
 
+    def get_sharding_config(self) -> Dict:
+        """Returns the sharding config for this model."""
+        return self._sharding_config
+
     def get_cache_config(self) -> CacheConfig:
         """Return the cache configuration for the model.
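A standalone sketch of the pattern introduced here (ToyModelFactory and ToyHFFactory are hypothetical stand-ins, not the repo's classes): the base class seeds the sharding config with an UNKNOWN source, and subclasses that know their provenance overwrite it before get_sharding_config() is queried:

# Toy reproduction of the source-tagging pattern from the diff above.
from enum import Enum
from typing import Any, Dict


class ShardingConfigSource(Enum):
    """Enum for factory source."""

    HUGGINGFACE = "huggingface"
    UNKNOWN = "unknown"


class ToyModelFactory:
    def __init__(self):
        # Base class knows nothing about provenance yet.
        self._sharding_config: Dict[str, Any] = {}
        self._sharding_config["source"] = ShardingConfigSource.UNKNOWN

    def get_sharding_config(self) -> Dict:
        """Returns the sharding config for this model."""
        return self._sharding_config


class ToyHFFactory(ToyModelFactory):
    def __init__(self):
        super().__init__()
        # Subclasses that know their provenance overwrite the source.
        self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE


print(ToyModelFactory().get_sharding_config())  # {'source': <ShardingConfigSource.UNKNOWN: 'unknown'>}
print(ToyHFFactory().get_sharding_config())     # {'source': <ShardingConfigSource.HUGGINGFACE: 'huggingface'>}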

tensorrt_llm/_torch/auto_deploy/models/hf.py
Lines changed: 34 additions & 1 deletion

@@ -29,7 +29,7 @@
 from ..custom_ops.attention_interface import CacheConfig
 from ..utils._config import deep_merge_dicts
 from ..utils.logger import ad_logger
-from .factory import ModelFactory, ModelFactoryRegistry
+from .factory import ModelFactory, ModelFactoryRegistry, ShardingConfigSource
 from .quant_config_reader import QuantConfigReader, QuantConfigReaderRegistry
 
 
@@ -94,6 +94,9 @@ def __init__(self, *args, **kwargs):
         assert isinstance(dtype, torch.dtype), f"Invalid dtype: {dtype}"
         self.model_kwargs["torch_dtype"] = dtype
 
+        # set sharding config source to huggingface
+        self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE
+
     @property
     def autoconfig_from_pretrained(self):
         return AutoConfig.from_pretrained
@@ -161,13 +164,30 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
         if hasattr(model, "post_init"):
             model.post_init()
 
+        # if present, initialize sharding config. We need head_dim for colwise sharding.
+        self._set_sharding_config(model.config)
+
         # patch forward method
         model.forward = types.MethodType(self._simple_forward, model)
 
         model.eval()
 
         return model
 
+    def _set_sharding_config(self, model_config: PretrainedConfig):
+        """Set the sharding config for the model."""
+        self._sharding_config["head_dim"] = 1
+        if hasattr(model_config, "base_model_tp_plan"):
+            self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
+        if hasattr(model_config, "head_dim") and model_config.head_dim is not None:
+            self._sharding_config["head_dim"] = model_config.head_dim
+        elif hasattr(model_config, "hidden_size") and hasattr(model_config, "num_attention_heads"):
+            self._sharding_config["head_dim"] = (
+                model_config.hidden_size // model_config.num_attention_heads
+            )
+        if hasattr(model_config, "num_hidden_layers"):
+            self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
+
     def get_quant_config(self) -> Dict:
         """Returns the quantization config for this model or an empty dict if not quantized."""
         if self._quant_config_reader is not None:
@@ -339,6 +359,19 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
         },
     }
 
+    def _set_sharding_config(self, model_config: PretrainedConfig):
+        """Override the sharding config for the model with text_config."""
+        super()._set_sharding_config(model_config)
+
+        if hasattr(model_config, "text_config"):
+            text_config = model_config.text_config
+            if hasattr(text_config, "base_model_tp_plan"):
+                self._sharding_config["tp_plan"] = text_config.base_model_tp_plan
+            if hasattr(text_config, "head_dim"):
+                self._sharding_config["head_dim"] = text_config.head_dim
+            if hasattr(text_config, "num_hidden_layers"):
+                self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers
+
     @property
     def automodel_from_config(self):
         return AutoModelForImageTextToText.from_config
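To make the head_dim fallback concrete, a self-contained sketch of the same logic; extract_sharding_config and the SimpleNamespace config are illustrative stand-ins, not the factory method or a real HuggingFace PretrainedConfig:

# Standalone check of the head_dim fallback mirrored from _set_sharding_config above:
# prefer an explicit head_dim, else derive hidden_size // num_attention_heads.
from types import SimpleNamespace
from typing import Any, Dict


def extract_sharding_config(model_config: Any) -> Dict[str, Any]:
    sharding_config: Dict[str, Any] = {"head_dim": 1}
    if hasattr(model_config, "base_model_tp_plan"):
        sharding_config["tp_plan"] = model_config.base_model_tp_plan
    if hasattr(model_config, "head_dim") and model_config.head_dim is not None:
        sharding_config["head_dim"] = model_config.head_dim
    elif hasattr(model_config, "hidden_size") and hasattr(model_config, "num_attention_heads"):
        sharding_config["head_dim"] = model_config.hidden_size // model_config.num_attention_heads
    if hasattr(model_config, "num_hidden_layers"):
        sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
    return sharding_config


# A Llama-8B-like shape with no explicit head_dim: the fallback yields 4096 // 32 = 128.
cfg = SimpleNamespace(
    hidden_size=4096,
    num_attention_heads=32,
    num_hidden_layers=32,
    base_model_tp_plan={"layers.*.self_attn.q_proj": "colwise"},  # made-up plan entry
)
print(extract_sharding_config(cfg))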
