Skip to content

Commit

Permalink
only produce tensorboard logs on rank 0 by default
Browse files (browse the repository at this point in the history)
ghstack-source-id: 1d228f271db275dd229fae61b3ca064141afcacb
Pull Request resolved: #339
  • Loading branch information
tianyu-l committed May 16, 2024
1 parent 847189d commit 0efa88e
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 10 deletions.
1 change: 0 additions & 1 deletion .ci/docker/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
torch >= 2.2.0.dev
datasets
tomli >= 1.1.0 ; python_version < "3.11"
tensorboard
Expand Down
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ def __init__(self):
default="tb",
help="Folder to dump TensorBoard states",
)
self.parser.add_argument(
"--metrics.rank_0_only",
default=True,
action="store_true",
help="Whether to save TensorBoard metrics only for rank 0 or for all ranks",
)

# model configs
self.parser.add_argument(
Expand Down
15 changes: 10 additions & 5 deletions torchtitan/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,21 @@ def close(self):

def build_metric_logger(config: JobConfig, tag: Optional[str] = None):
dump_dir = config.job.dump_folder
save_tb_folder = config.metrics.save_tb_folder
# since we don't have run id yet, use current minute as identifier
tb_config = config.metrics
save_tb_folder = tb_config.save_tb_folder
# since we don't have run id, use current minute as the identifier
datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)

enable_tb = config.metrics.enable_tensorboard
enable_tb = tb_config.enable_tensorboard
if enable_tb:
logger.info(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
)
if tb_config.rank_0_only:
enable_tb = torch.distributed.get_rank() == 0
else:
rank_str = f"rank_{torch.distributed.get_rank()}"
log_dir = os.path.join(log_dir, rank_str)

rank_str = f"rank_{torch.distributed.get_rank()}"
return MetricLogger(os.path.join(log_dir, rank_str), tag, enable_tb)
return MetricLogger(log_dir, tag, enable_tb)
4 changes: 2 additions & 2 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
# Apply tensor + sequence parallelism to every transformer block
for layer_id, transformer_block in model.layers.items():
layer_plan = {
"attention_norm": SequenceParallel(),
"attention": PrepareModuleInput(
input_layouts=(Shard(1), None),
desired_input_layouts=(Replicate(), None),
Expand All @@ -182,15 +183,14 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
"attention.wk": col_parallel_strategy(),
"attention.wv": col_parallel_strategy(),
"attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
"attention_norm": SequenceParallel(),
"ffn_norm": SequenceParallel(),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": col_parallel_strategy(),
"feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
"feed_forward.w3": col_parallel_strategy(),
"ffn_norm": SequenceParallel(),
}

# Adjust attention module to use the local number of heads
Expand Down
3 changes: 1 addition & 2 deletions train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
[job]
dump_folder = "./outputs"
description = "Llama 3 debug training"
# TODO: turn this back on once ci have tokenizer
use_for_integration_test = true

[profiling]
Expand Down Expand Up @@ -48,7 +47,7 @@ interval_type = "steps"
interval = 5
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'selective' # ['none', 'selective', 'full']
Expand Down

0 comments on commit 0efa88e

Please sign in to comment.