Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modifying memory estimation options and minor changes #435

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,18 @@ def estimate_memory(job_config: JobConfig):
# fake tensor doesn't work with fused rmsnorm
if (
job_config.model.norm_type == "fused_rmsnorm"
and job_config.estimate.mode == "fake"
and job_config.memory_estimation.fake_mode_only
):
logger.info(
"Fused RMSNorm is not supported yet under fake estimation mode. "
"Switching to rmsnorm."
)
job_config.model.norm_type = "rmsnorm"

if job_config.training.compile:
logger.info("Compile mode is not supported yet. " "Switching to Eager mode.")
job_config.training.compile = False

parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
Expand Down Expand Up @@ -107,7 +111,7 @@ def loss_fn(pred, labels):
model_config.vocab_size = tokenizer.n_words
model_config.max_seq_len = job_config.training.seq_len

with FakeTensorMode() if job_config.estimate.mode == "fake" else contextlib.nullcontext():
with FakeTensorMode() if job_config.memory_estimation.fake_mode_only else contextlib.nullcontext():

logger.info(
f"Building {model_name} {job_config.model.flavor} with {model_config}"
Expand Down Expand Up @@ -198,7 +202,7 @@ def loss_fn(pred, labels):
f" {peak_reserved / gib} GiB | num_retries: {num_retries}"
)
print(f"Tracker Max: {tracker_peak / gib} GiB")
if job_config.estimate.mode == "real":
if not job_config.memory_estimation.fake_mode_only and peak_active > 0:
print(f"Tracker Accuracy: {tracker_peak/peak_active}")
gc.enable()

Expand Down
2 changes: 1 addition & 1 deletion run_llama_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ if [ $# -ne 0 ]; then
fi

# Check if --memory_estimation.enabled is in the arguments
if echo "$overrides" | grep -q -- "--estimate.memory=True"; then
if echo "$overrides" | grep -q -- "--memory_estimation.enabled"; then
# Calculate WORLD_SIZE as the product of NGPU and NNODES
# Export WORLD_SIZE and LOCAL_RANK
export WORLD_SIZE=$((NGPU * NNODES))
Expand Down
7 changes: 5 additions & 2 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,14 @@ def build_test_list():
),
OverrideDefinitions(
[
["--estimate.memory=True", "--estimate.mode=real"],
[
"--memory_estimation.enabled",
"--memory_estimation.fake_mode_only",
]
],
"FSDP2 Memory Tracking and Estimation",
"fsdp2_mem_tracker",
ngpu=4,
ngpu=8,
sanketpurandare marked this conversation as resolved.
Show resolved Hide resolved
),
]
return integration_tests_flavors
Expand Down
13 changes: 6 additions & 7 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,18 +480,17 @@ def __init__(self):
help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
)

# estimation mode settings
# memory estimation settings
self.parser.add_argument(
"--estimate.memory",
"--memory_estimation.enabled",
help="Whether to estimate memory usage for FSDP",
default=False,
action="store_true",
)

self.parser.add_argument(
"--estimate.mode",
type=str,
default="fake",
help="Mode of estimation to use ['fake', 'real']",
"--memory_estimation.fake_mode_only",
help="Whether to estimate memory under FakeTensorMode",
action="store_true",
sanketpurandare marked this conversation as resolved.
Show resolved Hide resolved
)

def parse_args(self, args_list: list = sys.argv[1:]):
Expand Down
Loading