 from ..custom_ops.attention_interface import CacheConfig
 from ..utils._config import deep_merge_dicts
 from ..utils.logger import ad_logger
-from .factory import ModelFactory, ModelFactoryRegistry
+from .factory import ModelFactory, ModelFactoryRegistry, ShardingConfigSource
 from .quant_config_reader import QuantConfigReader, QuantConfigReaderRegistry


@@ -94,6 +94,9 @@ def __init__(self, *args, **kwargs):
         assert isinstance(dtype, torch.dtype), f"Invalid dtype: {dtype}"
         self.model_kwargs["torch_dtype"] = dtype

+        # set sharding config source to huggingface
+        self._sharding_config["source"] = ShardingConfigSource.HUGGINGFACE
+
     @property
     def autoconfig_from_pretrained(self):
         return AutoConfig.from_pretrained
@@ -161,13 +164,30 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
         if hasattr(model, "post_init"):
             model.post_init()

+        # if present, initialize sharding config. We need head_dim for colwise sharding.
+        self._set_sharding_config(model.config)
+
         # patch forward method
         model.forward = types.MethodType(self._simple_forward, model)

         model.eval()

         return model

+    def _set_sharding_config(self, model_config: PretrainedConfig):
+        """Set the sharding config for the model."""
+        self._sharding_config["head_dim"] = 1
+        if hasattr(model_config, "base_model_tp_plan"):
+            self._sharding_config["tp_plan"] = model_config.base_model_tp_plan
+        if hasattr(model_config, "head_dim") and model_config.head_dim is not None:
+            self._sharding_config["head_dim"] = model_config.head_dim
+        elif hasattr(model_config, "hidden_size") and hasattr(model_config, "num_attention_heads"):
+            self._sharding_config["head_dim"] = (
+                model_config.hidden_size // model_config.num_attention_heads
+            )
+        if hasattr(model_config, "num_hidden_layers"):
+            self._sharding_config["num_hidden_layers"] = model_config.num_hidden_layers
+
     def get_quant_config(self) -> Dict:
         """Returns the quantization config for this model or an empty dict if not quantized."""
         if self._quant_config_reader is not None:
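
Note: the head_dim resolution added above follows the common Hugging Face convention: prefer an explicit head_dim on the config, otherwise derive it as hidden_size // num_attention_heads, and keep the default of 1 if neither is available. A minimal standalone sketch of that fallback order, using a hypothetical resolve_head_dim helper and a SimpleNamespace in place of a real PretrainedConfig:

    from types import SimpleNamespace

    def resolve_head_dim(cfg, default: int = 1) -> int:
        """Hypothetical helper mirroring the fallback order used in the diff above."""
        # 1) an explicit head_dim wins if it is set and not None
        if getattr(cfg, "head_dim", None) is not None:
            return cfg.head_dim
        # 2) otherwise derive it from hidden_size / num_attention_heads
        if hasattr(cfg, "hidden_size") and hasattr(cfg, "num_attention_heads"):
            return cfg.hidden_size // cfg.num_attention_heads
        # 3) fall back to the default (1 in the diff above)
        return default

    # e.g. a Llama-style config without an explicit head_dim
    cfg = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
    assert resolve_head_dim(cfg) == 128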
@@ -339,6 +359,19 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory):
            },
        }

+    def _set_sharding_config(self, model_config: PretrainedConfig):
+        """Override the sharding config for the model with text_config."""
+        super()._set_sharding_config(model_config)
+
+        if hasattr(model_config, "text_config"):
+            text_config = model_config.text_config
+            if hasattr(text_config, "base_model_tp_plan"):
+                self._sharding_config["tp_plan"] = text_config.base_model_tp_plan
+            if hasattr(text_config, "head_dim"):
+                self._sharding_config["head_dim"] = text_config.head_dim
+            if hasattr(text_config, "num_hidden_layers"):
+                self._sharding_config["num_hidden_layers"] = text_config.num_hidden_layers
+
     @property
     def automodel_from_config(self):
         return AutoModelForImageTextToText.from_config
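
Note: for image-text-to-text models the attention geometry relevant to sharding lives on the nested text_config, so the override re-reads tp_plan, head_dim, and num_hidden_layers from there when present. A rough sketch of that precedence, again with SimpleNamespace objects standing in for real Hugging Face configs and a hypothetical sharding_values helper:

    from types import SimpleNamespace

    def sharding_values(cfg) -> dict:
        """Hypothetical illustration: nested text_config values win over top-level ones."""
        values = {
            "head_dim": getattr(cfg, "head_dim", 1),
            "num_hidden_layers": getattr(cfg, "num_hidden_layers", None),
        }
        text_cfg = getattr(cfg, "text_config", None)
        if text_cfg is not None:
            # the language-model sub-config overrides the top-level values
            for key in ("head_dim", "num_hidden_layers"):
                if hasattr(text_cfg, key):
                    values[key] = getattr(text_cfg, key)
        return values

    vlm_cfg = SimpleNamespace(
        num_hidden_layers=2,  # top-level value (may describe the vision side)
        text_config=SimpleNamespace(head_dim=128, num_hidden_layers=40),
    )
    assert sharding_values(vlm_cfg) == {"head_dim": 128, "num_hidden_layers": 40}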