-
Notifications
You must be signed in to change notification settings - Fork 279
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introducing a generic
ModelConverter
interface. (#823)
This model converter interface should cover most cases in quantization, fused layer optimization, ... This PR adds: * A generic interface for a `ModelConverter` class, transforming a model; * An argument `model.converters` where the user can add a list of converters to apply to the model (e.g. `float8`); * Converting `Float8Handler` to the `ModelConverter` interface. Related issue: #790
- Loading branch information
Showing
16 changed files
with
218 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from torchtitan.config_manager import JobConfig | ||
from torchtitan.float8 import Float8Converter | ||
from torchtitan.model_converter import build_model_converters, ModelConvertersContainer | ||
from torchtitan.parallelisms import ParallelDims | ||
|
||
|
||
def build_parallel_dims(job_config, world_size):
    """Construct a ``ParallelDims`` from the parallelism settings in *job_config*.

    Loss-parallel is enabled unless explicitly disabled in the training config.
    """
    training = job_config.training
    experimental = job_config.experimental
    return ParallelDims(
        dp_shard=training.data_parallel_shard_degree,
        dp_replicate=training.data_parallel_replicate_degree,
        cp=experimental.context_parallel_degree,
        tp=training.tensor_parallel_degree,
        pp=experimental.pipeline_parallel_degree,
        world_size=world_size,
        enable_loss_parallel=not training.disable_loss_parallel,
    )
|
||
|
||
def test_build_model_converters_empty_list():
    """With no ``--model.converters`` given, the container holds no converters."""
    config = JobConfig()
    config.parse_args([])
    parallel_dims = build_parallel_dims(config, 1)

    container = build_model_converters(config, parallel_dims)
    assert isinstance(container, ModelConvertersContainer)
    assert len(container.converters) == 0
|
||
|
||
def test_build_model_converters_float8_converter():
    """Passing ``--model.converters float8`` yields exactly one Float8Converter."""
    config = JobConfig()
    config.parse_args(["--model.converters", "float8"])
    parallel_dims = build_parallel_dims(config, 1)

    container = build_model_converters(config, parallel_dims)
    assert isinstance(container, ModelConvertersContainer)
    assert len(container.converters) == 1
    assert isinstance(container.converters[0], Float8Converter)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
from typing import Dict, List, Protocol, Union | ||
|
||
import torch.nn as nn | ||
|
||
from torchtitan.config_manager import JobConfig | ||
from torchtitan.parallelisms import ParallelDims | ||
|
||
|
||
class ModelConverter(Protocol):
    """General model converter interface.

    A model converter applies a modification to a PyTorch model.
    Typical use cases are:
        - Quantization: using QAT, FP8, ... specialized linear layers;
        - Fused optimized layers (e.g. flash-attention, norms, ...)
    """

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        ...

    def convert(self, model: nn.Module):
        """In-place conversion of the model."""
        ...

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        """Optional hook to run after each optimizer step (e.g. to compute
        weight statistics). ``model`` may be a single module or a list of
        modules — presumably model chunks under pipelining; confirm with callers.
        """
        ...
|
||
|
||
# Maps a converter name (as passed via the `model.converters` job config)
# to its ModelConverter class. Populated through `register_model_converter`.
_registry_model_converter_cls: Dict[str, type[ModelConverter]] = {}
"""Registry of model converter classes.
"""
|
||
|
||
def register_model_converter(converter_cls: type[ModelConverter], name: str):
    """Register a model converter class under *name*.

    A registered model converter can be applied on any model
    using the `model.converters` config parameter.

    Raises:
        ValueError: if a converter is already registered under *name*.
    """
    # Raise a real exception instead of `assert`: assertions are stripped
    # under `python -O`, which would silently allow duplicate registrations.
    if name in _registry_model_converter_cls:
        raise ValueError(f"A model converter '{name}' is already registered.")
    _registry_model_converter_cls[name] = converter_cls
|
||
|
||
class ModelConvertersContainer(ModelConverter):
    """Sequential container of model converters.

    Builds the converters named in the `model.converters` job config (in
    order) and applies each of them to the model, one after another.
    """

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        # Unknown names raise KeyError here, surfacing config typos early.
        self.converters = [
            _registry_model_converter_cls[name](job_config, parallel_dims)
            for name in job_config.model.converters
        ]

    def convert(self, model: nn.Module):
        for converter in self.converters:
            converter.convert(model)

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        for converter in self.converters:
            converter.post_optimizer_hook(model)
|
||
|
||
def build_model_converters(
    job_config: JobConfig, parallel_dims: ParallelDims
) -> ModelConvertersContainer:
    """Build the collection of model converters to apply to the model.

    The converters applied are those listed in the `model.converters` job
    config; each must have been registered via `register_model_converter`.
    """
    return ModelConvertersContainer(job_config, parallel_dims)
Oops, something went wrong.