Skip to content

Commit

Permalink
Update on "Modifying memory estimation options and minor changes"
Browse files (browse the repository at this point in the history)
As per suggestions from tianyu-l in #425, the config options are now:
`./run_llama_train.sh --memory_estimation.enabled --memory_estimation.disable_fake_mode`

[ghstack-poisoned]
  • Loading branch information
sanketpurandare committed Jul 1, 2024
1 parent a273a49 commit 6dc4cb0
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def estimate_memory(job_config: JobConfig):
# fake tensor doesn't work with fused rmsnorm
if (
job_config.model.norm_type == "fused_rmsnorm"
and job_config.memory_estimation.fake_mode_only
and not job_config.memory_estimation.disable_fake_mode
):
logger.info(
"Fused RMSNorm is not supported yet under fake estimation mode. "
Expand Down Expand Up @@ -111,7 +111,7 @@ def loss_fn(pred, labels):
model_config.vocab_size = tokenizer.n_words
model_config.max_seq_len = job_config.training.seq_len

with FakeTensorMode() if job_config.memory_estimation.fake_mode_only else contextlib.nullcontext():
with FakeTensorMode() if not job_config.memory_estimation.disable_fake_mode else contextlib.nullcontext():

logger.info(
f"Building {model_name} {job_config.model.flavor} with {model_config}"
Expand Down Expand Up @@ -202,7 +202,7 @@ def loss_fn(pred, labels):
f" {peak_reserved / gib} GiB | num_retries: {num_retries}"
)
print(f"Tracker Max: {tracker_peak / gib} GiB")
if not job_config.memory_estimation.fake_mode_only and peak_active > 0:
if job_config.memory_estimation.disable_fake_mode and peak_active > 0:
print(f"Tracker Accuracy: {tracker_peak/peak_active}")
gc.enable()

Expand Down
3 changes: 1 addition & 2 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,11 @@ def build_test_list():
[
[
"--memory_estimation.enabled",
"--memory_estimation.fake_mode_only",
]
],
"FSDP2 Memory Tracking and Estimation",
"fsdp2_mem_tracker",
ngpu=8,
ngpu=4,
),
]
return integration_tests_flavors
Expand Down
3 changes: 2 additions & 1 deletion torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,9 @@ def __init__(self):
)

self.parser.add_argument(
"--memory_estimation.fake_mode_only",
"--memory_estimation.disable_fake_mode",
help="Whether to estimate memory under FakeTensorMode",
default=False,
action="store_true",
)

Expand Down

0 comments on commit 6dc4cb0

Please sign in to comment.