diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index b82120a6d..f18a51c17 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -1,4 +1,4 @@
-torch >= 2.2.0.dev
+torch >= 2.4.0.dev
 datasets
 tomli >= 1.1.0 ; python_version < "3.11"
 tensorboard
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 1de3c82c9..957c4efb9 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -113,6 +113,12 @@ def __init__(self):
             default="tb",
             help="Folder to dump TensorBoard states",
         )
+        self.parser.add_argument(
+            "--metrics.rank_0_only",
+            default=True,
+            action="store_true",
+            help="Whether to save TensorBoard metrics only for rank 0 or for all ranks",
+        )

         # model configs
         self.parser.add_argument(
diff --git a/torchtitan/metrics.py b/torchtitan/metrics.py
index 90108976b..b9b9cabdc 100644
--- a/torchtitan/metrics.py
+++ b/torchtitan/metrics.py
@@ -113,16 +113,21 @@ def close(self):

 def build_metric_logger(config: JobConfig, tag: Optional[str] = None):
     dump_dir = config.job.dump_folder
-    save_tb_folder = config.metrics.save_tb_folder
-    # since we don't have run id yet, use current minute as identifier
+    tb_config = config.metrics
+    save_tb_folder = tb_config.save_tb_folder
+    # since we don't have run id, use current minute as the identifier
     datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
     log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)

-    enable_tb = config.metrics.enable_tensorboard
+    enable_tb = tb_config.enable_tensorboard
     if enable_tb:
         logger.info(
             f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
         )
+        if tb_config.rank_0_only:
+            enable_tb = torch.distributed.get_rank() == 0
+        else:
+            rank_str = f"rank_{torch.distributed.get_rank()}"
+            log_dir = os.path.join(log_dir, rank_str)

-    rank_str = f"rank_{torch.distributed.get_rank()}"
-    return MetricLogger(os.path.join(log_dir, rank_str), tag, enable_tb)
+    return MetricLogger(log_dir, tag, enable_tb)
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 9c8d0a29f..9567cc390 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -174,6 +174,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         # Apply tensor + sequence parallelism to every transformer block
         for layer_id, transformer_block in model.layers.items():
             layer_plan = {
+                "attention_norm": SequenceParallel(),
                 "attention": PrepareModuleInput(
                     input_layouts=(Shard(1), None),
                     desired_input_layouts=(Replicate(), None),
@@ -182,7 +183,7 @@
                 "attention.wk": col_parallel_strategy(),
                 "attention.wv": col_parallel_strategy(),
                 "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
-                "attention_norm": SequenceParallel(),
+                "ffn_norm": SequenceParallel(),
                 "feed_forward": PrepareModuleInput(
                     input_layouts=(Shard(1),),
                     desired_input_layouts=(Replicate(),),
@@ -190,7 +191,6 @@
                 "feed_forward.w1": col_parallel_strategy(),
                 "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
                 "feed_forward.w3": col_parallel_strategy(),
-                "ffn_norm": SequenceParallel(),
             }

             # Adjust attention module to use the local number of heads
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 4541fec7b..7ed6a5a0a 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -3,7 +3,6 @@
 [job]
 dump_folder = "./outputs"
 description = "Llama 3 debug training"
-# TODO: turn this back on once ci have tokenizer
 use_for_integration_test = true

 [profiling]
@@ -48,7 +47,7 @@ interval_type = "steps"
 interval = 5
 model_weights_only = false
 export_dtype = "float32"
-async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

 [activation_checkpoint]
 mode = 'selective'  # ['none', 'selective', 'full']
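
For reviewers, here is a minimal standalone sketch of the log-directory logic that the `torchtitan/metrics.py` hunk adds to `build_metric_logger`, assuming `torch.distributed` is already initialized and `enable_tensorboard` is on. The helper name `resolve_tb_log_dir` and the explicit `rank` parameter are illustrative stand-ins, not part of the patch:

```python
# Sketch of the directory/enable logic added to build_metric_logger above.
# `rank` stands in for torch.distributed.get_rank(); assumes enable_tensorboard=True.
import os
from datetime import datetime


def resolve_tb_log_dir(dump_dir: str, save_tb_folder: str, rank: int, rank_0_only: bool):
    """Return (log_dir, enable_tb) the way the patched code computes them."""
    # same run identifier as before: current minute
    datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
    log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)
    enable_tb = True
    if rank_0_only:
        # one shared folder; only rank 0 keeps its TensorBoard writer enabled
        enable_tb = rank == 0
    else:
        # previous behavior: every rank writes into its own rank_<i> subfolder
        log_dir = os.path.join(log_dir, f"rank_{rank}")
    return log_dir, enable_tb


# e.g. on rank 3 with dump_folder="./outputs", save_tb_folder="tb":
#   rank_0_only=True  -> ("./outputs/tb/<timestamp>", False)
#   rank_0_only=False -> ("./outputs/tb/<timestamp>/rank_3", True)
```

With `rank_0_only=True` (the new default) every rank computes the same folder but only rank 0 keeps its writer enabled; with `rank_0_only=False` the old per-rank `rank_<i>` subfolders are preserved.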
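
The three `parallelize_llama.py` hunks only reorder the per-block TP plan so that each norm sits directly before the module it feeds. Collected in one place as a sketch (plain `ColwiseParallel`/`RowwiseParallel` stand in for the file's `col_parallel_strategy`/`row_parallel_strategy` aliases, and the unchanged `attention.wq` entry is assumed from the surrounding context), the resulting plan reads in execution order:

```python
# Net per-transformer-block TP plan after the reorder; sketch only.
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

layer_plan = {
    "attention_norm": SequenceParallel(),
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1), None),
        desired_input_layouts=(Replicate(), None),
    ),
    "attention.wq": ColwiseParallel(),  # unchanged context line, not shown in the hunks
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "ffn_norm": SequenceParallel(),
    "feed_forward": PrepareModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}
```

Since the plan is applied by submodule name, the dict ordering does not change parallelization behavior; the reorder is organizational, keeping the plan aligned with the block's execution order.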