From 0c901554397a688b4aa030d4117a4ec3d5876eb5 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Mon, 20 May 2024 17:18:54 -0700
Subject: [PATCH] Update (base update)

[ghstack-poisoned]
---
 torchtitan/checkpoint.py                     |  4 +--
 torchtitan/config_manager.py                 | 31 +++++---------------
 torchtitan/parallelisms/parallelize_llama.py | 10 +++----
 3 files changed, 15 insertions(+), 30 deletions(-)

diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py
index 2e1fdf67..81bdf592 100644
--- a/torchtitan/checkpoint.py
+++ b/torchtitan/checkpoint.py
@@ -22,7 +22,7 @@
     set_optimizer_state_dict,
 )
 from torch.distributed.checkpoint.stateful import Stateful
-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging_utils import init_logger, logger
 
 
@@ -134,7 +134,7 @@ def __init__(
         self.pg = dist.new_group(backend="gloo")
 
         self.model_weights_only = ckpt_config.model_weights_only
-        self.export_dtype = ckpt_config.export_dtype
+        self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
         self.mp = None
         async_mode = ckpt_config.async_mode.lower()
 
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 7484eb93..1e13a677 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -18,22 +18,12 @@
 
 from torchtitan.logging_utils import logger
 
-DTYPE_MAP = {
+TORCH_DTYPE_MAP = {
     "float16": torch.float16,
     "float32": torch.float32,
     "bfloat16": torch.bfloat16,
 }
 
-TORCH_DTYPE_ARGS = [
-    "checkpoint.export_dtype",
-    "training.mixed_precision_param",
-    "training.mixed_precision_reduce",
-]
-
-
-def torch_dtype(dtype_str: str) -> torch.dtype:
-    return DTYPE_MAP[dtype_str]
-
 
 def string_list(raw_arg):
     return raw_arg.split(",")
@@ -289,8 +279,8 @@ def __init__(self):
         )
         self.parser.add_argument(
             "--training.mixed_precision_param",
-            type=torch_dtype,
-            default=torch_dtype("bfloat16"),
+            type=str,
+            default="bfloat16",
             choices=["bfloat16", "float32"],
             help="""
                 torch dtype to use for parameters when applying mixed precision via FSDP.
@@ -299,8 +289,8 @@ def __init__(self):
         )
         self.parser.add_argument(
             "--training.mixed_precision_reduce",
-            type=torch_dtype,
-            default=torch_dtype("float32"),
+            type=str,
+            default="float32",
             choices=["float32"],
             help="""
                 torch dtype to use for reductions when applying mixed precision via FSDP.
@@ -373,8 +363,8 @@ def __init__(self):
         )
         self.parser.add_argument(
             "--checkpoint.export_dtype",
-            type=torch_dtype,
-            default=torch_dtype("float32"),
+            type=str,
+            default="float32",
             choices=["float16", "bfloat16", "float32"],
             help="""
                 Converts to the specified precision when training completes and model_weights_only=true.
@@ -462,9 +452,6 @@ def parse_args(self, args_list: list = sys.argv[1:]):
         try:
             with open(config_file, "rb") as f:
                 for k, v in tomllib.load(f).items():
-                    for k_, v_ in v.items():
-                        if ".".join([k, k_]) in TORCH_DTYPE_ARGS:
-                            v[k_] = torch_dtype(v_)
                     # to prevent overwrite of non-specified keys
                     args_dict[k] |= v
         except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
@@ -508,9 +495,7 @@ def parse_args_from_command_line(
         # aux parser to parse the command line only args, with no defaults from main parser
         aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
         for arg, val in vars(args).items():
-            if arg in TORCH_DTYPE_ARGS:
-                aux_parser.add_argument("--" + arg, type=torch_dtype)
-            elif isinstance(val, bool):
+            if isinstance(val, bool):
                 aux_parser.add_argument(
                     "--" + arg, action="store_true" if val else "store_false"
                 )
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 8cecdb4a..1265495f 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -33,7 +33,7 @@
 
 from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint
 
-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging_utils import logger
 from torchtitan.parallelisms.pipelining_utils import split_stage_fqns
 
@@ -226,7 +226,7 @@ def pipeline_llama_manual(
                 int(job_config.training.seq_len // parallel_dims.tp),
                 model_config.dim,
             ),
-            dtype=job_config.training.mixed_precision_param
+            dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
             if parallel_dims.dp_enabled
             else torch.float32,
             device=device,
@@ -251,7 +251,7 @@ def pipeline_llama_manual(
                 int(job_config.training.seq_len // parallel_dims.tp),
                 model_config.dim,
             ),
-            dtype=job_config.training.mixed_precision_param
+            dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
             if parallel_dims.dp_enabled
             else torch.float32,
             device=device,
@@ -392,8 +392,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
         mp_policy = MixedPrecisionPolicy(
-            param_dtype=job_config.training.mixed_precision_param,
-            reduce_dtype=job_config.training.mixed_precision_reduce,
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )
         ac_mode = job_config.activation_checkpoint.mode
         fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
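
Note on the pattern this patch adopts: dtypes now stay as plain strings in argparse defaults and the TOML file, and the conversion to a torch.dtype happens only at the consuming sites via TORCH_DTYPE_MAP. A minimal standalone sketch of that flow follows; the `cfg` dict is a hypothetical stand-in for the parsed JobConfig values, not torchtitan code.

    import torch

    # Same string-to-dtype mapping the patch keeps in torchtitan/config_manager.py.
    TORCH_DTYPE_MAP = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }

    # Hypothetical parsed config values; in torchtitan these arrive as strings
    # from argparse defaults, CLI flags, or the TOML config file.
    cfg = {
        "training.mixed_precision_param": "bfloat16",
        "training.mixed_precision_reduce": "float32",
        "checkpoint.export_dtype": "float32",
    }

    # Conversion happens only where an actual torch.dtype is needed, e.g. when
    # building an FSDP MixedPrecisionPolicy or exporting checkpoint weights.
    param_dtype = TORCH_DTYPE_MAP[cfg["training.mixed_precision_param"]]
    reduce_dtype = TORCH_DTYPE_MAP[cfg["training.mixed_precision_reduce"]]
    export_dtype = TORCH_DTYPE_MAP[cfg["checkpoint.export_dtype"]]
    assert param_dtype is torch.bfloat16 and export_dtype is torch.float32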