Skip to content

Commit

Permalink
W&B wandb support (#699)
Browse files Browse the repository at this point in the history
This PR adds experimental wandb support, not sure this is "landable"
considering y'all uses tensorboard by default. Personally I vastly
prefer wandb because I can share my training runs with a link and don't
need to muck around with ssh tunneling so I'm just opening this since
I'm using it myself. If there's interest I can work to land this.

To use this you just kick of a training as usual with
`CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh
` but also run `wandb login` and paste in your token
![Screenshot 2024-11-25 at 12 16
20 PM](https://github.com/user-attachments/assets/4d8c3893-2bb2-435e-b9bb-69558b8ea7ea)

Changes in logs will look like

![Screenshot 2024-11-25 at 12 27
42 PM](https://github.com/user-attachments/assets/24760ef3-21d5-4292-bdef-cddb5e916e6b)

Also only slightly related but llama 3 tokenizer is not available on hf
anymore so added instructions for 3.1 and 3.2


<details>
<summary>Click here for detailed logs.</summary>
[rank0]:2024-11-25 11:33:24,320 - root - INFO - Dumping traces at step
1000
[rank0]:2024-11-25 11:33:24,576 - root - INFO - Finished dumping traces
in 0.26 seconds
[rank0]:2024-11-25 11:33:24,577 - root - INFO - Sleeping 2 seconds for
other ranks to complete
[rank0]:wandb:
[rank0]:wandb: 
[rank0]:wandb: Run history:
[rank0]:wandb: loss_metrics/global_avg_loss
█▆▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[rank0]:wandb: loss_metrics/global_max_loss
█▇▄▄▃▃▄▃▃▆▃▃▃▃▃▃▂▂▂▂▃▂▂▃▁▂▂▂▁▃▂▁▂▁▂▂▁▄▁▁
[rank0]:wandb: memory/max_active(%)
▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[rank0]:wandb: memory/max_active(GiB)
▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[rank0]:wandb: memory/max_reserved(%)
▁███████████████████████████████████████
[rank0]:wandb: memory/max_reserved(GiB)
▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[rank0]:wandb: memory/num_alloc_retries
▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[rank0]:wandb: memory/num_ooms ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[rank0]:wandb: mfu(%) ▁███████▇██████▇█████████▇█▇████████████
[rank0]:wandb: step ▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇████
[rank0]:wandb: time_metrics/data_loading(%)
▁▁▁▁▂▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▂▁▁▂▁▁▁▂
[rank0]:wandb: time_metrics/data_loading(s)
▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂
[rank0]:wandb: time_metrics/end_to_end(s)
▁▇▇▇▇█▇▇▇█▇▇▇▇▇▇▇▇▇▇▇▇██▇▇▇█▇▇▇▇▇▇█▇█▇▇▇
[rank0]:wandb: wps ███▁████▄█▇▅████████▅▄████▇███▇▄████▇██▇
[rank0]:wandb: 
[rank0]:wandb: Run summary:
[rank0]:wandb: loss_metrics/global_avg_loss 4.53519
[rank0]:wandb: loss_metrics/global_max_loss 4.99517
[rank0]:wandb:         memory/max_active(%) 43.33611
[rank0]:wandb:       memory/max_active(GiB) 41.17145
[rank0]:wandb:       memory/max_reserved(%) 52.19301
[rank0]:wandb:     memory/max_reserved(GiB) 49.58594
[rank0]:wandb:     memory/num_alloc_retries 0
[rank0]:wandb:              memory/num_ooms 0
[rank0]:wandb:                       mfu(%) 30.75216
[rank0]:wandb:                         step 1000
[rank0]:wandb: time_metrics/data_loading(%) 1.01461
[rank0]:wandb: time_metrics/data_loading(s) 0.01583
[rank0]:wandb:   time_metrics/end_to_end(s) 1.55993
[rank0]:wandb:                          wps 5251.52034
[rank0]:wandb: 
[rank0]:wandb: 🚀 View run skilled-glitter-1 at:
https://wandb.ai/sahancpal-meta/torchtitan/runs/r1zqr75b
</details>

---------

Co-authored-by: tianyu-l <[email protected]>
  • Loading branch information
msaroufim and tianyu-l authored Dec 3, 2024
1 parent 9a4eebe commit 6a6f755
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 66 deletions.
1 change: 1 addition & 0 deletions .ci/docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ sentencepiece
tiktoken
blobfile
tabulate
wandb
24 changes: 2 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ You may want to see how the model is defined or how parallelism techniques are a
6. DDP and HSDP
7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries)
8. Learning rate scheduler, meta-init, (optional) fused RMSNorm kernel
9. Loss, GPU memory, throughput (tokens/sec), and MFU displayed and logged via [TensorBoard](#tensorboard)
9. Loss, GPU memory, throughput (tokens/sec), and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md)
10. Debugging tools including CPU/GPU profiling, [memory profiling](docs/memory_profiler.md), [Flight Recorder](#debugging), etc.
11. All options easily configured via [toml files](train_configs/)

Expand All @@ -73,7 +73,7 @@ We report our [Performance](docs/performance.md) verified on 64/128 GPUs.
git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall # or cu118
```

### Downloading a tokenizer
Expand All @@ -99,26 +99,6 @@ Llama 3 8B model locally on 8 GPUs
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh
```


## TensorBoard

To visualize TensorBoard metrics of models trained on a remote server via a local web browser:

1. Make sure `metrics.enable_tensorboard` option is set to true in model training (either from a .toml file or from CLI).

2. Set up SSH tunneling, by running the following from local CLI
```
ssh -L 6006:127.0.0.1:6006 [username]@[hostname]
```

3. Inside the SSH tunnel that logged into the remote server, go to the torchtitan repo, and start the TensorBoard backend
```
tensorboard --logdir=./outputs/tb
```

4. In the local web browser, go to the URL it provides OR to http://localhost:6006/.


## Multi-Node Training
For training on ParallelCluster/Slurm type configurations, you can use the `multinode_trainer.slurm` file to submit your sbatch job.

Expand Down
36 changes: 36 additions & 0 deletions docs/metrics.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Metrics

We support automatically collecting metrics such as
1. High level system metrics such as MFU, average loss, max loss and words per second along with some
2. Memory metrics to measure max VRAM consumption and the number of OOMs
3. Timing metrics to measure data loading bottlenecks

Those metrics can then be visualized in either a TensorBoard or WandDB dashboard

## TensorBoard

To visualize TensorBoard metrics of models trained on a remote server via a local web browser:

1. Make sure `metrics.enable_tensorboard` option is set to true in model training (either from a .toml file or from CLI).

2. Set up SSH tunneling, by running the following from local CLI
```
ssh -L 6006:127.0.0.1:6006 [username]@[hostname]
```

3. Inside the SSH tunnel that logged into the remote server, go to the torchtitan repo, and start the TensorBoard backend
```
tensorboard --logdir=./outputs/tb
```

4. In the local web browser, go to the URL it provides OR to http://localhost:6006/.

## Weights and Biases

Weights and Biases will automatically send metrics to a remote server if you login with `wandb login`

So all you need to do is make sure that `metrics.enable_wandb` is enabled

For an example you can inspect [debug_model.toml](../train_configs/debug_model.toml)

Note that if both W&B and Tensorboard are enabled then we will prioritize W&B.
19 changes: 13 additions & 6 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,16 @@ def __init__(self):
help="How often to log metrics to TensorBoard, in iterations",
)
self.parser.add_argument(
"--metrics.enable_color_printing",
default=False,
"--metrics.enable_tensorboard",
action="store_true",
help="Whether to enable color printing",
default=False,
help="Whether to log metrics to TensorBoard",
)
self.parser.add_argument(
"--metrics.enable_tensorboard",
"--metrics.enable_color_printing",
action="store_true",
help="Whether to log metrics to TensorBoard",
default=True,
help="Whether to enable color printing in logs",
)
self.parser.add_argument(
"--metrics.save_tb_folder",
Expand All @@ -139,14 +140,20 @@ def __init__(self):
)
self.parser.add_argument(
"--metrics.rank_0_only",
default=True,
action="store_true",
default=True,
help="""
Whether to save TensorBoard metrics only for rank 0 or for all ranks.
When pipeline_parallel_degree is > 1, this option uses the 0th rank of the last stage pipeline group,
which is the only stage that computes loss metrics.
""",
)
self.parser.add_argument(
"--metrics.enable_wandb",
action="store_true",
default=False,
help="Whether to log metrics to Weights & Biases",
)

# model configs
self.parser.add_argument(
Expand Down
146 changes: 108 additions & 38 deletions torchtitan/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from torchtitan.parallelisms import ParallelDims
from torchtitan.utils import device_module, device_type


# named tuple for passing device memory stats for logging
DeviceMemStats = namedtuple(
"DeviceMemStats",
Expand Down Expand Up @@ -90,29 +89,66 @@ def reset_peak_stats(self):
def build_device_memory_monitor():
device_memory_monitor = DeviceMemoryMonitor(device_type)
logger.info(
f"{device_type.upper()} capacity: {device_memory_monitor.device_name} ({device_memory_monitor.device_index}) "
f"{device_type.upper()} capacity: {device_memory_monitor.device_name}"
f"with {device_memory_monitor.device_capacity_gib:.2f}GiB memory"
)

return device_memory_monitor


class MetricLogger:
def __init__(self, log_dir, tag, enable_tb):
class BaseLogger:
"""Logger that does nothing, used when logging is disabled."""

def log(self, metrics: Dict[str, Any], step: int) -> None:
pass

def close(self) -> None:
pass


class TensorBoardLogger(BaseLogger):
"""Logger implementation for TensorBoard."""

def __init__(self, log_dir: str, tag: Optional[str] = None):
self.tag = tag
self.writer = SummaryWriter(log_dir, max_queue=1000)
logger.info(f"TensorBoard logging enabled. Logs will be saved at {log_dir}")

def log(self, metrics: Dict[str, Any], step: int) -> None:
for k, v in metrics.items():
tag = k if self.tag is None else f"{self.tag}/{k}"
self.writer.add_scalar(tag, v, step)

def close(self) -> None:
self.writer.close()


class WandBLogger(BaseLogger):
"""Logger implementation for Weights & Biases."""

def __init__(self, log_dir: str, tag: Optional[str] = None):
# Import wandb here to avoid startup import
import wandb

self.wandb = wandb
self.tag = tag
self.writer: Optional[SummaryWriter] = None
if enable_tb:
self.writer = SummaryWriter(log_dir, max_queue=1000)

def log(self, metrics: Dict[str, Any], step: int):
if self.writer is not None:
for k, v in metrics.items():
tag = k if self.tag is None else f"{self.tag}/{k}"
self.writer.add_scalar(tag, v, step)
self.wandb.init(
project="torchtitan",
dir=log_dir,
)
logger.info("WandB logging enabled")

def log(self, metrics: Dict[str, Any], step: int) -> None:
wandb_metrics = {
(k if self.tag is None else f"{self.tag}/{k}"): v
for k, v in metrics.items()
}
wandb_metrics["step"] = step
self.wandb.log(wandb_metrics)

def close(self):
if self.writer is not None:
self.writer.close()
def close(self) -> None:
if self.wandb.run is not None:
self.wandb.finish()


def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
Expand All @@ -126,35 +162,69 @@ def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
else:
metrics_log_rank = 0

return metrics_log_rank


def build_metric_logger(
job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
):
) -> BaseLogger:
"""
parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is
intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
parallelism is enabled, without forcing logging from all ranks to capture loss information.
Build an appropriate metric logger based on configuration.
"""
metrics_config = job_config.metrics

# Log initial config state
logger.debug(
f"Building logger with config: wandb={metrics_config.enable_wandb}, "
f"tensorboard={metrics_config.enable_tensorboard}"
)

# Check if any logging backend is enabled
has_logging_enabled = (
metrics_config.enable_tensorboard or metrics_config.enable_wandb
)

# Determine if this rank should log
should_log = has_logging_enabled
if metrics_config.rank_0_only and should_log:
metrics_rank = _get_metrics_rank(parallel_dims)
should_log = torch.distributed.get_rank() == metrics_rank

logger.debug(
f"Logging decision: has_logging_enabled={has_logging_enabled}, should_log={should_log}"
)

if not should_log:
logger.debug("Returning BaseLogger due to should_log=False")
return BaseLogger()

# Setup logging directory
dump_dir = job_config.job.dump_folder
tb_config = job_config.metrics
save_tb_folder = tb_config.save_tb_folder
# since we don't have run id, use current minute as the identifier
datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)

enable_tb = tb_config.enable_tensorboard
if enable_tb:
logger.info(
f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
base_log_dir = os.path.join(
dump_dir, metrics_config.save_tb_folder, datetime.now().strftime("%Y%m%d-%H%M")
)

if not metrics_config.rank_0_only:
base_log_dir = os.path.join(
base_log_dir, f"rank_{torch.distributed.get_rank()}"
)
if tb_config.rank_0_only:
enable_tb = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims)
else:
rank_str = f"rank_{torch.distributed.get_rank()}"
log_dir = os.path.join(log_dir, rank_str)

return MetricLogger(log_dir, tag, enable_tb)
# Create loggers in priority order
if metrics_config.enable_wandb:
logger.debug("Attempting to create WandB logger")
try:
return WandBLogger(base_log_dir, tag)
except Exception as e:
if "No module named 'wandb'" in str(e):
logger.error(
"Failed to create WandB logger: No module named 'wandb'. Please install it using 'pip install wandb'."
)
else:
logger.error(f"Failed to create WandB logger: {e}")

if metrics_config.enable_tensorboard:
logger.debug("Creating TensorBoard logger")
return TensorBoardLogger(base_log_dir, tag)

logger.debug("No loggers enabled, returning BaseLogger")
return BaseLogger()
1 change: 1 addition & 0 deletions train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ log_freq = 1
enable_color_printing = true
enable_tensorboard = true
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "llama3"
Expand Down

0 comments on commit 6a6f755

Please sign in to comment.