From 6a8455ec41457f2d8dd8e86f26c93b77a5cd0c06 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Tue, 28 May 2024 22:08:56 -0700 Subject: [PATCH] only produce tensorboard logs on rank 0 by default ghstack-source-id: 4255cc792b9a221bc5a012e91db92533dcfa2243 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/339 --- torchtitan/config_manager.py | 6 ++++++ torchtitan/metrics.py | 15 ++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index da80b425..6a730dcb 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -125,6 +125,12 @@ def __init__(self): default="tb", help="Folder to dump TensorBoard states", ) + self.parser.add_argument( + "--metrics.rank_0_only", + default=True, + action="store_true", + help="Whether to save TensorBoard metrics only for rank 0 or for all ranks", + ) # model configs self.parser.add_argument( diff --git a/torchtitan/metrics.py b/torchtitan/metrics.py index 90108976..b9b9cabd 100644 --- a/torchtitan/metrics.py +++ b/torchtitan/metrics.py @@ -113,16 +113,21 @@ def close(self): def build_metric_logger(config: JobConfig, tag: Optional[str] = None): dump_dir = config.job.dump_folder - save_tb_folder = config.metrics.save_tb_folder - # since we don't have run id yet, use current minute as identifier + tb_config = config.metrics + save_tb_folder = tb_config.save_tb_folder + # since we don't have run id, use current minute as the identifier datetime_str = datetime.now().strftime("%Y%m%d-%H%M") log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str) - enable_tb = config.metrics.enable_tensorboard + enable_tb = tb_config.enable_tensorboard if enable_tb: logger.info( f"Metrics logging active. Tensorboard logs will be saved at {log_dir}" ) + if tb_config.rank_0_only: + enable_tb = torch.distributed.get_rank() == 0 + else: + rank_str = f"rank_{torch.distributed.get_rank()}" + log_dir = os.path.join(log_dir, rank_str) - rank_str = f"rank_{torch.distributed.get_rank()}" - return MetricLogger(os.path.join(log_dir, rank_str), tag, enable_tb) + return MetricLogger(log_dir, tag, enable_tb)