Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
wconstab committed May 21, 2024
2 parents 37130f4 + 0c90155 commit e8b2f59
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 30 deletions.
4 changes: 2 additions & 2 deletions torchtitan/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torchtitan.config_manager import JobConfig
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import init_logger, logger


Expand Down Expand Up @@ -134,7 +134,7 @@ def __init__(
self.pg = dist.new_group(backend="gloo")

self.model_weights_only = ckpt_config.model_weights_only
self.export_dtype = ckpt_config.export_dtype
self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]

self.mp = None
async_mode = ckpt_config.async_mode.lower()
Expand Down
31 changes: 8 additions & 23 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,12 @@

from torchtitan.logging_utils import logger

DTYPE_MAP = {
TORCH_DTYPE_MAP = {
"float16": torch.float16,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}

TORCH_DTYPE_ARGS = [
"checkpoint.export_dtype",
"training.mixed_precision_param",
"training.mixed_precision_reduce",
]


def torch_dtype(dtype_str: str) -> torch.dtype:
return DTYPE_MAP[dtype_str]


def string_list(raw_arg):
return raw_arg.split(",")
Expand Down Expand Up @@ -289,8 +279,8 @@ def __init__(self):
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=torch_dtype,
default=torch_dtype("bfloat16"),
type=str,
default="bfloat16",
choices=["bfloat16", "float32"],
help="""
torch dtype to use for parameters when applying mixed precision via FSDP.
Expand All @@ -299,8 +289,8 @@ def __init__(self):
)
self.parser.add_argument(
"--training.mixed_precision_reduce",
type=torch_dtype,
default=torch_dtype("float32"),
type=str,
default="float32",
choices=["float32"],
help="""
torch dtype to use for reductions when applying mixed precision via FSDP.
Expand Down Expand Up @@ -373,8 +363,8 @@ def __init__(self):
)
self.parser.add_argument(
"--checkpoint.export_dtype",
type=torch_dtype,
default=torch_dtype("float32"),
type=str,
default="float32",
choices=["float16", "bfloat16", "float32"],
help="""
Converts to the specified precision when training completes and model_weights_only=true.
Expand Down Expand Up @@ -462,9 +452,6 @@ def parse_args(self, args_list: list = sys.argv[1:]):
try:
with open(config_file, "rb") as f:
for k, v in tomllib.load(f).items():
for k_, v_ in v.items():
if ".".join([k, k_]) in TORCH_DTYPE_ARGS:
v[k_] = torch_dtype(v_)
# to prevent overwrite of non-specified keys
args_dict[k] |= v
except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
Expand Down Expand Up @@ -508,9 +495,7 @@ def parse_args_from_command_line(
# aux parser to parse the command line only args, with no defaults from main parser
aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
for arg, val in vars(args).items():
if arg in TORCH_DTYPE_ARGS:
aux_parser.add_argument("--" + arg, type=torch_dtype)
elif isinstance(val, bool):
if isinstance(val, bool):
aux_parser.add_argument(
"--" + arg, action="store_true" if val else "store_false"
)
Expand Down
10 changes: 5 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint

from torchtitan.config_manager import JobConfig
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
from torchtitan.parallelisms.pipelining_utils import split_stage_fqns

Expand Down Expand Up @@ -232,7 +232,7 @@ def pipeline_llama_manual(
int(job_config.training.seq_len // parallel_dims.tp),
model_config.dim,
),
dtype=job_config.training.mixed_precision_param
dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
if parallel_dims.dp_enabled
else torch.float32,
device=device,
Expand All @@ -257,7 +257,7 @@ def pipeline_llama_manual(
int(job_config.training.seq_len // parallel_dims.tp),
model_config.dim,
),
dtype=job_config.training.mixed_precision_param
dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param]
if parallel_dims.dp_enabled
else torch.float32,
device=device,
Expand Down Expand Up @@ -398,8 +398,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
mp_policy = MixedPrecisionPolicy(
param_dtype=job_config.training.mixed_precision_param,
reduce_dtype=job_config.training.mixed_precision_reduce,
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
)
ac_mode = job_config.activation_checkpoint.mode
fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
Expand Down

0 comments on commit e8b2f59

Please sign in to comment.