From 3bd14ec4479baf7f4b24720088b996188f3b52e3 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 20 May 2024 17:18:53 -0700 Subject: [PATCH] Expose mixed_precision dtype arguments add training.mixed_precision_param and .mixed_precision_reduce options refactor a util to map strings to torch dtypes ghstack-source-id: 387e1ca13ad23e859d21d7760f858ee6e269a796 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/348 --- torchtitan/checkpoint.py | 11 ++------ torchtitan/config_manager.py | 29 ++++++++++++++++++++ torchtitan/parallelisms/parallelize_llama.py | 6 ++-- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index 33fe8c05..81bdf592 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -22,17 +22,10 @@ set_optimizer_state_dict, ) from torch.distributed.checkpoint.stateful import Stateful -from torchtitan.config_manager import JobConfig +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging_utils import init_logger, logger -DTYPE_MAP = { - "float16": torch.float16, - "float32": torch.float32, - "bfloat16": torch.bfloat16, -} - - class IntervalType(enum.Enum): SECONDS = enum.auto() STEPS = enum.auto() @@ -141,7 +134,7 @@ def __init__( self.pg = dist.new_group(backend="gloo") self.model_weights_only = ckpt_config.model_weights_only - self.export_dtype = DTYPE_MAP[ckpt_config.export_dtype] + self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype] self.mp = None async_mode = ckpt_config.async_mode.lower() diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 1de3c82c..1a3e36d4 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -9,6 +9,8 @@ from collections import defaultdict from typing import Tuple, Union +import torch + try: import tomllib except ModuleNotFoundError: @@ -16,6 +18,12 @@ from torchtitan.logging_utils import logger +TORCH_DTYPE_MAP = { + "float16": torch.float16, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + class JobConfig: """ @@ -207,6 +215,26 @@ def __init__(self): default=1, help="Pipeline Parallelism degree. 1 means disabled.", ) + self.parser.add_argument( + "--training.mixed_precision_param", + type=str, + default="bfloat16", + choices=["bfloat16", "float32"], + help=""" + torch dtype to use for parameters when applying mixed precision via FSDP. + This feature only takes effect when data_parallel_degree > 1 + """, + ) + self.parser.add_argument( + "--training.mixed_precision_reduce", + type=str, + default="float32", + choices=["float32"], + help=""" + torch dtype to use for reductions when applying mixed precision via FSDP. + This feature only takes effect when data_parallel_degree > 1 + """, + ) self.parser.add_argument( "--training.compile", action="store_true", @@ -275,6 +303,7 @@ def __init__(self): "--checkpoint.export_dtype", type=str, default="float32", + choices=["float16", "bfloat16", "float32"], help=""" Converts to the specified precision when training completes and model_weights_only=true. Currently supports float32, float16, and bfloat16. diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 9c8d0a29..0bd0a966 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -28,7 +28,7 @@ from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint -from torchtitan.config_manager import JobConfig +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP from torchtitan.logging_utils import logger @@ -209,9 +209,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): if parallel_dims.dp_enabled: dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names - # TODO: Expose `reduce_dtype` as a config option. mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32 + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) ac_mode = job_config.activation_checkpoint.mode fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}