Introducing a generic ModelConverter interface. (#823)
This model converter interface should cover most use cases: quantization,
fused layer optimization, and similar model transformations.

This PR adds:
* A generic `ModelConverter` interface for classes that transform a model;
* A `model.converters` argument where the user can specify a list of
converters to apply to the model (e.g. `float8`);
* A port of the existing `Float8Handler` to the `ModelConverter` interface (as `Float8Converter`).

Related issue: #790
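At a glance, the new interface boils down to two hooks on a converter container built from the job config. The sketch below is a minimal, self-contained illustration assembled from the pieces added in this diff; the `nn.Linear` is a stand-in for a real torchtitan model, and the parallelism setup mirrors the new unit tests rather than any changed training script.

```python
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.model_converter import build_model_converters
from torchtitan.parallelisms import ParallelDims

# Schematic setup mirroring the new unit tests. With the default (empty)
# `model.converters`, the container is a no-op; passing
# ["--model.converters", "float8"] would build the Float8Converter instead
# (which additionally requires torchao and an SM89+ GPU).
job_config = JobConfig()
job_config.parse_args([])
parallel_dims = ParallelDims(
    dp_shard=job_config.training.data_parallel_shard_degree,
    dp_replicate=job_config.training.data_parallel_replicate_degree,
    cp=job_config.experimental.context_parallel_degree,
    tp=job_config.training.tensor_parallel_degree,
    pp=job_config.experimental.pipeline_parallel_degree,
    world_size=1,
    enable_loss_parallel=not job_config.training.disable_loss_parallel,
)

model = nn.Linear(8, 8)  # stand-in for a real torchtitan model

model_converters = build_model_converters(job_config, parallel_dims)
model_converters.convert(model)              # once, before parallelisms/compilation
# ... forward, backward, optimizer.step() ...
model_converters.post_optimizer_hook(model)  # per-step hook, after optimizer.step()
```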
balancap authored Feb 15, 2025
1 parent f640689 commit 57387af
Showing 16 changed files with 218 additions and 57 deletions.
4 changes: 2 additions & 2 deletions docs/float8.md
@@ -7,9 +7,9 @@ USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git

Launch training job with the following command (or alternatively set configs in toml files)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
```
* `--float8.enable_float8_linear`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
* `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
* `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter.

17 changes: 9 additions & 8 deletions scripts/estimate/estimation.py
@@ -9,15 +9,16 @@
import os

import torch

import torchtitan.float8 # noqa
from torch._guards import active_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan import utils
from torchtitan.config_manager import JobConfig
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.model_converter import build_model_converters
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import ParallelDims
from torchtitan.train_spec import get_train_spec
@@ -117,10 +118,9 @@ def loss_fn(pred, labels):
with torch.device("meta"):
model = model_cls.from_model_args(model_config)

# a no-op hander if float8 is not enabled
float8_handler = Float8Handler(job_config, parallel_dims)
# swap to Float8Linear based on float8 configs
float8_handler.convert_to_float8_training(model)
# Build the collection of model converters. No-op if `model.converters` empty
model_converters = build_model_converters(job_config, parallel_dims)
model_converters.convert(model)

# apply PT-D DP/TP parallelisms and activation checkpointing
train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
@@ -170,9 +170,10 @@ def loss_fn(pred, labels):
# optimizer step
optimizers.step()
lr_schedulers.step()
# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# Post-optimizer model converters hook.
# e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
model_converters.post_optimizer_hook(model)
optimizers.zero_grad()
print(f"Peak Memory at iter: {iter_idx}")
fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
9 changes: 9 additions & 0 deletions tests/unit_tests/test_job_config.py
@@ -187,6 +187,15 @@ def test_parse_exclude_from_loading(self):
config.checkpoint.exclude_from_loading == cmdline_splits
), config.checkpoint.exclude_from_loading

def test_job_config_model_converters_split(self):
config = JobConfig()
config.parse_args([])
assert config.model.converters == []

config = JobConfig()
config.parse_args(["--model.converters", "float8,mxfp"])
assert config.model.converters == ["float8", "mxfp"]

def test_print_help(self):
config = JobConfig()
parser = config.parser
44 changes: 44 additions & 0 deletions tests/unit_tests/test_model_converter.py
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.config_manager import JobConfig
from torchtitan.float8 import Float8Converter
from torchtitan.model_converter import build_model_converters, ModelConvertersContainer
from torchtitan.parallelisms import ParallelDims


def build_parallel_dims(job_config, world_size):
parallel_dims = ParallelDims(
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
cp=job_config.experimental.context_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=not job_config.training.disable_loss_parallel,
)
return parallel_dims


def test_build_model_converters_empty_list():
config = JobConfig()
config.parse_args([])
parallel_dims = build_parallel_dims(config, 1)

model_converters = build_model_converters(config, parallel_dims)
assert isinstance(model_converters, ModelConvertersContainer)
assert model_converters.converters == []


def test_build_model_converters_float8_converter():
config = JobConfig()
config.parse_args(["--model.converters", "float8"])
parallel_dims = build_parallel_dims(config, 1)

model_converters = build_model_converters(config, parallel_dims)
assert isinstance(model_converters, ModelConvertersContainer)
assert len(model_converters.converters) == 1
assert isinstance(model_converters.converters[0], Float8Converter)
3 changes: 3 additions & 0 deletions torchtitan/__init__.py
@@ -6,6 +6,9 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Import to register Float8Converter.
import torchtitan.float8 # noqa: F401

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models # noqa: F401
70 changes: 39 additions & 31 deletions torchtitan/config_manager.py
@@ -26,9 +26,22 @@


def string_list(raw_arg):
"""Comma-separated string list argument."""
return [s.strip() for s in raw_arg.split(",") if s.strip()]


def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
section, name = fullargname.split(".")
# Split string-list values that are still raw strings.
if (
section in args_dict
and name in args_dict[section]
and isinstance(args_dict[section][name], str)
):
sec = args_dict[section]
sec[name] = string_list(sec[name])
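For clarity, a small illustration of how these two helpers behave; the dict literal below is illustrative, standing in for a section of a parsed TOML config:

```python
from torchtitan.config_manager import check_string_list_argument, string_list

# `string_list` splits a comma-separated value, stripping whitespace and empty entries.
assert string_list("float8, mxfp,") == ["float8", "mxfp"]

# Values loaded from a TOML file arrive as raw strings and get split in place.
args_dict = {"model": {"converters": "float8,mxfp"}}
check_string_list_argument(args_dict, "model.converters")
assert args_dict["model"]["converters"] == ["float8", "mxfp"]

# Values that are already lists (e.g. parsed from the command line) are left untouched.
check_string_list_argument(args_dict, "model.converters")
assert args_dict["model"]["converters"] == ["float8", "mxfp"]
```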


class JobConfig:
"""
A helper class to manage the train configuration.
@@ -183,6 +196,19 @@ def __init__(self):
default="./torchtitan/datasets/tokenizer/tokenizer.model",
help="Tokenizer path",
)
self.parser.add_argument(
"--model.converters",
type=string_list,
nargs="+",
default=[],
help="""
Comma-separated list of converters to apply to the model.
For instance, the `float8` converter swaps `torch.nn.Linear`
with `Float8Linear`. This feature requires you to install 'torchao'
which can be found here: https://github.com/pytorch/ao
""",
)

# optimizer configs
self.parser.add_argument(
@@ -575,15 +601,6 @@ def __init__(self):
)

# float8 configs
self.parser.add_argument(
"--float8.enable_float8_linear",
action="store_true",
help="""
If true, swaps `torch.nn.Linear` with `Float8Linear`.
This feature requires you to install 'torchao' which can be found
here: https://github.com/pytorch/ao
""",
)
self.parser.add_argument(
"--float8.enable_fsdp_float8_all_gather",
action="store_true",
@@ -652,25 +669,11 @@ def parse_args(self, args_list: list = sys.argv[1:]):
logger.exception(f"Error details: {str(e)}")
raise e

# Checking string-list arguments are properly split into a list
# if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
if (
"experimental" in args_dict
and "pipeline_parallel_split_points" in args_dict["experimental"]
and isinstance(
args_dict["experimental"]["pipeline_parallel_split_points"], str
)
):
exp = args_dict["experimental"]
exp["pipeline_parallel_split_points"] = string_list(
exp["pipeline_parallel_split_points"]
)
if (
"checkpoint" in args_dict
and "exclude_from_loading" in args_dict["checkpoint"]
and isinstance(args_dict["checkpoint"]["exclude_from_loading"], str)
):
ckpt = args_dict["checkpoint"]
ckpt["exclude_from_loading"] = string_list(ckpt["exclude_from_loading"])
string_list_argnames = self._get_string_list_argument_names()
for n in string_list_argnames:
check_string_list_argument(args_dict, n)

# override args dict with cmd_args
cmd_args_dict = self._args_to_two_level_dict(cmd_args)
@@ -698,13 +701,21 @@ def _validate_config(self) -> None:
assert self.model.flavor
assert self.model.tokenizer_path

def _get_string_list_argument_names(self) -> list[str]:
"""Get the parser argument names of type `string_list`."""
string_list_args = [
v.dest for v in self.parser._actions if v.type is string_list
]
return string_list_args

def parse_args_from_command_line(
self, args_list
) -> Tuple[argparse.Namespace, argparse.Namespace]:
"""
Parse command line arguments and return the parsed args and the command line only args
"""
args = self.parser.parse_args(args_list)
string_list_argnames = set(self._get_string_list_argument_names())

# aux parser to parse the command line only args, with no defaults from main parser
aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
@@ -713,14 +724,11 @@ def parse_args_from_command_line(
aux_parser.add_argument(
"--" + arg, action="store_true" if val else "store_false"
)
elif arg == "experimental.pipeline_parallel_split_points":
elif arg in string_list_argnames:
# without this special case, type inference breaks here,
# since the inferred type is just 'list' and it ends up flattening
# e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
aux_parser.add_argument("--" + arg, type=string_list)
elif arg == "checkpoint.exclude_from_loading":
# similar to the case above
aux_parser.add_argument("--" + arg, type=string_list)
else:
aux_parser.add_argument("--" + arg, type=type(val))

14 changes: 11 additions & 3 deletions torchtitan/float8.py
@@ -20,6 +20,7 @@

from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.model_converter import ModelConverter, register_model_converter
from torchtitan.parallelisms import ParallelDims


@@ -28,13 +29,11 @@ def _is_sm89_or_later():
return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


class Float8Handler:
class Float8Converter(ModelConverter):
def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
self.enabled = False

float8_config = job_config.float8
if not float8_config.enable_float8_linear:
return
if not _is_sm89_or_later():
logger.warning(
"Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
@@ -66,6 +65,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):

logger.info("Float8 training active")

def convert(self, model: nn.Module):
return self.convert_to_float8_training(model)

def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
return self.precompute_float8_dynamic_scale_for_fsdp(model)

def convert_to_float8_training(self, model: nn.Module):
"""
This function converts the linear layers of `model` to `Float8Linear`.
@@ -102,3 +107,6 @@ def precompute_float8_dynamic_scale_for_fsdp(
models = [model] if isinstance(model, nn.Module) else model
for m in models:
precompute_float8_dynamic_scale_for_fsdp(m)


register_model_converter(Float8Converter, "float8")
1 change: 1 addition & 0 deletions torchtitan/metrics.py
@@ -11,6 +11,7 @@

import torch
from torch.utils.tensorboard import SummaryWriter

from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.parallelisms import ParallelDims
80 changes: 80 additions & 0 deletions torchtitan/model_converter.py
@@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Protocol, Union

import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.parallelisms import ParallelDims


class ModelConverter(Protocol):
"""General model converter interface.
A model converter applies a modification to a PyTorch model.
Typical use cases are:
- Quantization: using QAT, FP8, ... specialized linear layers;
- Fused optimized layers (e.g. flash-attention, norms, ...)
"""

def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
...

def convert(self, model: nn.Module):
"""Inplace convertion of the model."""
...

def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
"""Post-optimizer (optional) hook (e.g. compute weights statistics)."""
...


_registry_model_converter_cls: Dict[str, type[ModelConverter]] = {}
"""Registry of model converter classes.
"""


def register_model_converter(converter_cls: type[ModelConverter], name: str):
"""Register a model converter class.
A registered model converter can be applied on any model
using the `model.converters` config parameter.
"""
assert (
name not in _registry_model_converter_cls
), f"A model converter '{name}' is already registered."
_registry_model_converter_cls[name] = converter_cls


class ModelConvertersContainer(ModelConverter):
"""Model converters sequential container.
The class build the sequence of model converters defined in `model.converters`
job config, and apply them to the model sequentially.
"""

def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
converter_classes = [
_registry_model_converter_cls[name] for name in job_config.model.converters
]
self.converters = [
mh_cls(job_config, parallel_dims) for mh_cls in converter_classes
]

def convert(self, model: nn.Module):
for mh in self.converters:
mh.convert(model)

def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
for mh in self.converters:
mh.post_optimizer_hook(model)


def build_model_converters(
job_config: JobConfig, parallel_dims: ParallelDims
) -> ModelConvertersContainer:
"""Build the collection of model converters to apply to the model."""
return ModelConvertersContainer(job_config, parallel_dims)
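As a usage sketch, a third-party converter only needs to satisfy the protocol and register itself under a name; the `NoOpConverter` class and the `"noop"` name below are hypothetical, not part of this PR:

```python
from typing import List, Union

import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.model_converter import register_model_converter
from torchtitan.parallelisms import ParallelDims


class NoOpConverter:
    """Hypothetical converter illustrating the `ModelConverter` protocol."""

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        pass

    def convert(self, model: nn.Module):
        # A real converter would modify `model` in place here,
        # e.g. swap layers for quantized or fused implementations.
        pass

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        # Optional bookkeeping after each optimizer.step().
        pass


# Once registered, the converter can be selected with `--model.converters="noop"`.
register_model_converter(NoOpConverter, "noop")
```

Since `ModelConverter` is a `typing.Protocol`, explicit subclassing is optional; any class with a matching constructor, `convert`, and `post_optimizer_hook` can be registered.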