Skip to content

Commit

Permalink
Modifying memory estimation options and minor changes
Browse files Browse the repository at this point in the history
ghstack-source-id: ff6be23d4a64c0de121a13fecd3166f677b9d51f
Pull Request resolved: #435
  • Loading branch information
sanketpurandare committed Jun 28, 2024
1 parent cb73810 commit fd33ba2
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 13 deletions.
10 changes: 7 additions & 3 deletions estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,18 @@ def estimate_memory(job_config: JobConfig):
# fake tensor doesn't work with fused rmsnorm
if (
job_config.model.norm_type == "fused_rmsnorm"
and job_config.estimate.mode == "fake"
and job_config.memory_estimation.fake_mode_only
):
logger.info(
"Fused RMSNorm is not supported yet under fake estimation mode. "
"Switching to rmsnorm."
)
job_config.model.norm_type = "rmsnorm"

if job_config.training.compile:
logger.info("Compile mode is not supported yet. " "Switching to Eager mode.")
job_config.training.compile = False

parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
Expand Down Expand Up @@ -107,7 +111,7 @@ def loss_fn(pred, labels):
model_config.vocab_size = tokenizer.n_words
model_config.max_seq_len = job_config.training.seq_len

with FakeTensorMode() if job_config.estimate.mode == "fake" else contextlib.nullcontext():
with FakeTensorMode() if job_config.memory_estimation.fake_mode_only else contextlib.nullcontext():

logger.info(
f"Building {model_name} {job_config.model.flavor} with {model_config}"
Expand Down Expand Up @@ -198,7 +202,7 @@ def loss_fn(pred, labels):
f" {peak_reserved / gib} GiB | num_retries: {num_retries}"
)
print(f"Tracker Max: {tracker_peak / gib} GiB")
if job_config.estimate.mode == "real":
if not job_config.memory_estimation.fake_mode_only and peak_active > 0:
print(f"Tracker Accuracy: {tracker_peak/peak_active}")
gc.enable()

Expand Down
2 changes: 1 addition & 1 deletion run_llama_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ if [ $# -ne 0 ]; then
fi

# Check if --estimate.memory=True is in the arguments
if echo "$overrides" | grep -q -- "--estimate.memory=True"; then
if echo "$overrides" | grep -q -- "--memory_estimation.enabled"; then
# Calculate WORLD_SIZE as the product of NGPU and NNODES
# Export WORLD_SIZE and LOCAL_RANK
export WORLD_SIZE=$((NGPU * NNODES))
Expand Down
7 changes: 5 additions & 2 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,14 @@ def build_test_list():
),
OverrideDefinitions(
[
["--estimate.memory=True", "--estimate.mode=real"],
[
"--memory_estimation.enabled",
"--memory_estimation.fake_mode_only",
]
],
"FSDP2 Memory Tracking and Estimation",
"fsdp2_mem_tracker",
ngpu=4,
ngpu=8,
),
]
return integration_tests_flavors
Expand Down
13 changes: 6 additions & 7 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,18 +480,17 @@ def __init__(self):
help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
)

# estimation mode settings
# memory estimation settings
self.parser.add_argument(
"--estimate.memory",
"--memory_estimation.enabled",
help="Whether to estimate memory usage for FSDP",
default=False,
action="store_true",
)

self.parser.add_argument(
"--estimate.mode",
type=str,
default="fake",
help="Mode of estimation to use ['fake', 'real']",
"--memory_estimation.fake_mode_only",
help="Whether to estimate memory under FakeTensorMode",
action="store_true",
)

def parse_args(self, args_list: list = sys.argv[1:]):
Expand Down

0 comments on commit fd33ba2

Please sign in to comment.