diff --git a/estimation.py b/estimation.py index 70bd3e60..e82a7b71 100644 --- a/estimation.py +++ b/estimation.py @@ -49,7 +49,7 @@ def estimate_memory(job_config: JobConfig): # fake tensor doesn't work with fused rmsnorm if ( job_config.model.norm_type == "fused_rmsnorm" - and job_config.estimate.mode == "fake" + and not job_config.memory_estimation.disable_fake_mode ): logger.info( "Fused RMSNorm is not supported yet under fake estimation mode. " @@ -57,6 +57,10 @@ def estimate_memory(job_config: JobConfig): ) job_config.model.norm_type = "rmsnorm" + if job_config.training.compile: + logger.info("Compile mode is not supported yet. " "Switching to Eager mode.") + job_config.training.compile = False + parallel_dims = ParallelDims( dp=job_config.training.data_parallel_degree, tp=job_config.training.tensor_parallel_degree, @@ -107,7 +111,7 @@ def loss_fn(pred, labels): model_config.vocab_size = tokenizer.n_words model_config.max_seq_len = job_config.training.seq_len - with FakeTensorMode() if job_config.estimate.mode == "fake" else contextlib.nullcontext(): + with FakeTensorMode() if not job_config.memory_estimation.disable_fake_mode else contextlib.nullcontext(): logger.info( f"Building {model_name} {job_config.model.flavor} with {model_config}" @@ -198,7 +202,7 @@ def loss_fn(pred, labels): f" {peak_reserved / gib} GiB | num_retries: {num_retries}" ) print(f"Tracker Max: {tracker_peak / gib} GiB") - if job_config.estimate.mode == "real": + if job_config.memory_estimation.disable_fake_mode and peak_active > 0: print(f"Tracker Accuracy: {tracker_peak/peak_active}") gc.enable() diff --git a/run_llama_train.sh b/run_llama_train.sh index ca2001a8..cf4943a6 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -31,7 +31,7 @@ if [ $# -ne 0 ]; then fi # Check if --estimate.memory=True is in the arguments -if echo "$overrides" | grep -q -- "--estimate.memory=True"; then +if echo "$overrides" | grep -q -- "--memory_estimation.enabled"; then # Calculate WORLD_SIZE as the product of NGPU and NNODES # Export WORLD_SIZE and LOCAL_RANK export WORLD_SIZE=$((NGPU * NNODES)) diff --git a/test_runner.py b/test_runner.py index 63377edd..cba63544 100755 --- a/test_runner.py +++ b/test_runner.py @@ -265,7 +265,9 @@ def build_test_list(): ), OverrideDefinitions( [ - ["--estimate.memory=True", "--estimate.mode=real"], + [ + "--memory_estimation.enabled", + ] ], "FSDP2 Memory Tracking and Estimation", "fsdp2_mem_tracker", diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 2ff216e1..6930bb7c 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -480,18 +480,18 @@ def __init__(self): help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled", ) - # estimation mode settings + # memory estimation settings self.parser.add_argument( - "--estimate.memory", + "--memory_estimation.enabled", help="Whether to estimate memory usage for FSDP", - default=False, + action="store_true", ) self.parser.add_argument( - "--estimate.mode", - type=str, - default="fake", - help="Mode of estimation to use ['fake', 'real']", + "--memory_estimation.disable_fake_mode", + help="Whether to estimate memory under FakeTensorMode", + default=False, + action="store_true", ) def parse_args(self, args_list: list = sys.argv[1:]):