diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index 2321627e..b7e3d935 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -6,3 +6,4 @@ sentencepiece
 tiktoken
 blobfile
 tabulate
+wandb
diff --git a/README.md b/README.md
index 99d4e99d..c87f939b 100644
--- a/README.md
+++ b/README.md
@@ -60,7 +60,7 @@ You may want to see how the model is defined or how parallelism techniques are a
 6. DDP and HSDP
 7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries)
 8. Learning rate scheduler, meta-init, (optional) fused RMSNorm kernel
-9. Loss, GPU memory, throughput (tokens/sec), and MFU displayed and logged via [TensorBoard](#tensorboard)
+9. Loss, GPU memory, throughput (tokens/sec), and MFU displayed and logged via [TensorBoard or Weights & Biases](docs/metrics.md)
 10. Debugging tools including CPU/GPU profiling, [memory profiling](docs/memory_profiler.md), [Flight Recorder](#debugging), etc.
 11. All options easily configured via [toml files](train_configs/)
 
@@ -73,7 +73,7 @@ We report our [Performance](docs/performance.md) verified on 64/128 GPUs.
 git clone https://github.com/pytorch/torchtitan
 cd torchtitan
 pip install -r requirements.txt
-pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
+pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall # or cu118
 ```
 
 ### Downloading a tokenizer
@@ -99,26 +99,6 @@ Llama 3 8B model locally on 8 GPUs
 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh
 ```
 
-
-## TensorBoard
-
-To visualize TensorBoard metrics of models trained on a remote server via a local web browser:
-
-1. Make sure `metrics.enable_tensorboard` option is set to true in model training (either from a .toml file or from CLI).
-
-2. Set up SSH tunneling, by running the following from local CLI
-```
-ssh -L 6006:127.0.0.1:6006 [username]@[hostname]
-```
-
-3. Inside the SSH tunnel that logged into the remote server, go to the torchtitan repo, and start the TensorBoard backend
-```
-tensorboard --logdir=./outputs/tb
-```
-
-4. In the local web browser, go to the URL it provides OR to http://localhost:6006/.
-
-
 ## Multi-Node Training
 For training on ParallelCluster/Slurm type configurations, you can use the `multinode_trainer.slurm` file to submit your sbatch job.
 
diff --git a/docs/metrics.md b/docs/metrics.md
new file mode 100644
index 00000000..568e3a87
--- /dev/null
+++ b/docs/metrics.md
@@ -0,0 +1,36 @@
+# Metrics
+
+We support automatically collecting metrics such as
+1. High level system metrics such as MFU, average loss, max loss, and words per second
+2. Memory metrics to measure max VRAM consumption and the number of OOMs
+3. Timing metrics to measure data loading bottlenecks
+
+Those metrics can then be visualized in either a TensorBoard or Weights & Biases dashboard.
+
+## TensorBoard
+
+To visualize TensorBoard metrics of models trained on a remote server via a local web browser:
+
+1. Make sure `metrics.enable_tensorboard` option is set to true in model training (either from a .toml file or from CLI).
+
+2. Set up SSH tunneling, by running the following from local CLI
+```
+ssh -L 6006:127.0.0.1:6006 [username]@[hostname]
+```
+
+3. Inside the SSH tunnel that logged into the remote server, go to the torchtitan repo, and start the TensorBoard backend
+```
+tensorboard --logdir=./outputs/tb
+```
+
+4. In the local web browser, go to the URL it provides OR to http://localhost:6006/.
+
+## Weights and Biases
+
+Weights and Biases will automatically send metrics to a remote server if you log in with `wandb login`.
+
+So all you need to do is make sure that `metrics.enable_wandb` is enabled.
+
+For an example, you can inspect [debug_model.toml](../train_configs/debug_model.toml).
+
+Note that if both W&B and TensorBoard are enabled, we will prioritize W&B.
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 814bd80f..e7bca6f1 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -121,15 +121,16 @@ def __init__(self):
             help="How often to log metrics to TensorBoard, in iterations",
         )
         self.parser.add_argument(
-            "--metrics.enable_color_printing",
-            default=False,
+            "--metrics.enable_tensorboard",
             action="store_true",
-            help="Whether to enable color printing",
+            default=False,
+            help="Whether to log metrics to TensorBoard",
         )
         self.parser.add_argument(
-            "--metrics.enable_tensorboard",
+            "--metrics.enable_color_printing",
             action="store_true",
-            help="Whether to log metrics to TensorBoard",
+            default=True,
+            help="Whether to enable color printing in logs",
         )
         self.parser.add_argument(
             "--metrics.save_tb_folder",
@@ -139,14 +140,20 @@ def __init__(self):
         )
         self.parser.add_argument(
             "--metrics.rank_0_only",
-            default=True,
             action="store_true",
+            default=True,
             help="""
                 Whether to save TensorBoard metrics only for rank 0 or for all ranks.
                 When pipeline_parallel_degree is > 1, this option uses the 0th rank of the last stage pipeline group,
                 which is the only stage that computes loss metrics.
             """,
         )
+        self.parser.add_argument(
+            "--metrics.enable_wandb",
+            action="store_true",
+            default=False,
+            help="Whether to log metrics to Weights & Biases",
+        )
 
         # model configs
         self.parser.add_argument(
diff --git a/torchtitan/metrics.py b/torchtitan/metrics.py
index 525a4bf8..6409dcd6 100644
--- a/torchtitan/metrics.py
+++ b/torchtitan/metrics.py
@@ -16,7 +16,6 @@
 from torchtitan.parallelisms import ParallelDims
 from torchtitan.utils import device_module, device_type
 
-
 # named tuple for passing device memory stats for logging
 DeviceMemStats = namedtuple(
     "DeviceMemStats",
@@ -90,29 +89,66 @@ def reset_peak_stats(self):
 def build_device_memory_monitor():
     device_memory_monitor = DeviceMemoryMonitor(device_type)
     logger.info(
-        f"{device_type.upper()} capacity: {device_memory_monitor.device_name} ({device_memory_monitor.device_index}) "
+        f"{device_type.upper()} capacity: {device_memory_monitor.device_name} "
         f"with {device_memory_monitor.device_capacity_gib:.2f}GiB memory"
     )
-
     return device_memory_monitor
 
 
-class MetricLogger:
-    def __init__(self, log_dir, tag, enable_tb):
+class BaseLogger:
+    """Logger that does nothing, used when logging is disabled."""
+
+    def log(self, metrics: Dict[str, Any], step: int) -> None:
+        pass
+
+    def close(self) -> None:
+        pass
+
+
+class TensorBoardLogger(BaseLogger):
+    """Logger implementation for TensorBoard."""
+
+    def __init__(self, log_dir: str, tag: Optional[str] = None):
+        self.tag = tag
+        self.writer = SummaryWriter(log_dir, max_queue=1000)
+        logger.info(f"TensorBoard logging enabled. Logs will be saved at {log_dir}")
+
+    def log(self, metrics: Dict[str, Any], step: int) -> None:
+        for k, v in metrics.items():
+            tag = k if self.tag is None else f"{self.tag}/{k}"
+            self.writer.add_scalar(tag, v, step)
+
+    def close(self) -> None:
+        self.writer.close()
+
+
+class WandBLogger(BaseLogger):
+    """Logger implementation for Weights & Biases."""
+
+    def __init__(self, log_dir: str, tag: Optional[str] = None):
+        # Import wandb lazily so it is only required when W&B logging is enabled
+        import wandb
+
+        self.wandb = wandb
         self.tag = tag
-        self.writer: Optional[SummaryWriter] = None
-        if enable_tb:
-            self.writer = SummaryWriter(log_dir, max_queue=1000)
 
-    def log(self, metrics: Dict[str, Any], step: int):
-        if self.writer is not None:
-            for k, v in metrics.items():
-                tag = k if self.tag is None else f"{self.tag}/{k}"
-                self.writer.add_scalar(tag, v, step)
+        self.wandb.init(
+            project="torchtitan",
+            dir=log_dir,
+        )
+        logger.info("WandB logging enabled")
+
+    def log(self, metrics: Dict[str, Any], step: int) -> None:
+        wandb_metrics = {
+            (k if self.tag is None else f"{self.tag}/{k}"): v
+            for k, v in metrics.items()
+        }
+        wandb_metrics["step"] = step
+        self.wandb.log(wandb_metrics)
 
-    def close(self):
-        if self.writer is not None:
-            self.writer.close()
+    def close(self) -> None:
+        if self.wandb.run is not None:
+            self.wandb.finish()
 
 
 def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
@@ -126,35 +162,69 @@ def _get_metrics_rank(parallel_dims: ParallelDims) -> int:
         metrics_log_rank = (world_size // pp_size) * (pp_size - 1)
     else:
         metrics_log_rank = 0
-
     return metrics_log_rank
 
 
 def build_metric_logger(
     job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
-):
+) -> BaseLogger:
     """
-    parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
-    In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is
-    intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
-    parallelism is enabled, without forcing logging from all ranks to capture loss information.
+    Build an appropriate metric logger based on configuration.
     """
+    metrics_config = job_config.metrics
+
+    # Log initial config state
+    logger.debug(
+        f"Building logger with config: wandb={metrics_config.enable_wandb}, "
+        f"tensorboard={metrics_config.enable_tensorboard}"
+    )
+
+    # Check if any logging backend is enabled
+    has_logging_enabled = (
+        metrics_config.enable_tensorboard or metrics_config.enable_wandb
+    )
+
+    # Determine if this rank should log
+    should_log = has_logging_enabled
+    if metrics_config.rank_0_only and should_log:
+        metrics_rank = _get_metrics_rank(parallel_dims)
+        should_log = torch.distributed.get_rank() == metrics_rank
+
+    logger.debug(
+        f"Logging decision: has_logging_enabled={has_logging_enabled}, should_log={should_log}"
+    )
+
+    if not should_log:
+        logger.debug("Returning BaseLogger due to should_log=False")
+        return BaseLogger()
+
+    # Setup logging directory
     dump_dir = job_config.job.dump_folder
-    tb_config = job_config.metrics
-    save_tb_folder = tb_config.save_tb_folder
-    # since we don't have run id, use current minute as the identifier
-    datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
-    log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)
-
-    enable_tb = tb_config.enable_tensorboard
-    if enable_tb:
-        logger.info(
-            f"Metrics logging active. Tensorboard logs will be saved at {log_dir}"
+    base_log_dir = os.path.join(
+        dump_dir, metrics_config.save_tb_folder, datetime.now().strftime("%Y%m%d-%H%M")
+    )
+
+    if not metrics_config.rank_0_only:
+        base_log_dir = os.path.join(
+            base_log_dir, f"rank_{torch.distributed.get_rank()}"
         )
-        if tb_config.rank_0_only:
-            enable_tb = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims)
-        else:
-            rank_str = f"rank_{torch.distributed.get_rank()}"
-            log_dir = os.path.join(log_dir, rank_str)
-
-    return MetricLogger(log_dir, tag, enable_tb)
+
+    # Create loggers in priority order
+    if metrics_config.enable_wandb:
+        logger.debug("Attempting to create WandB logger")
+        try:
+            return WandBLogger(base_log_dir, tag)
+        except Exception as e:
+            if "No module named 'wandb'" in str(e):
+                logger.error(
+                    "Failed to create WandB logger: No module named 'wandb'. Please install it using 'pip install wandb'."
+                )
+            else:
+                logger.error(f"Failed to create WandB logger: {e}")
+
+    if metrics_config.enable_tensorboard:
+        logger.debug("Creating TensorBoard logger")
+        return TensorBoardLogger(base_log_dir, tag)
+
+    logger.debug("No loggers enabled, returning BaseLogger")
+    return BaseLogger()
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index da3bc45e..f681cdba 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -17,6 +17,7 @@ log_freq = 1
 enable_color_printing = true
 enable_tensorboard = true
 save_tb_folder = "tb"
+enable_wandb = false
 
 [model]
 name = "llama3"
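As a usage sketch (not part of the patch above): the training loop only ever needs the common `BaseLogger` interface introduced here, so whichever backend `build_metric_logger` selects, the call sites stay the same. The metric names and the temporary directory below are made up purely for illustration.

```python
# Minimal sketch of driving the new logger classes directly, assuming torchtitan
# (and the tensorboard package) are installed. Metric names are illustrative only.
import tempfile

from torchtitan.metrics import BaseLogger, TensorBoardLogger


def run_demo(metric_logger: BaseLogger) -> None:
    # Every backend (no-op BaseLogger, TensorBoardLogger, WandBLogger) exposes the
    # same log/close interface, so the caller never branches on which one is active.
    for step in range(1, 4):
        metric_logger.log({"loss": 1.0 / step, "wps": 1000.0 * step}, step)
    metric_logger.close()


run_demo(BaseLogger())  # logging disabled: every call is a no-op
with tempfile.TemporaryDirectory() as tmp_dir:
    run_demo(TensorBoardLogger(tmp_dir, tag="train"))  # writes event files under tmp_dir
```

In normal use these classes are not constructed by hand: `build_metric_logger(job_config, parallel_dims)` picks W&B, TensorBoard, or the no-op logger based on `metrics.enable_wandb`, `metrics.enable_tensorboard`, and `metrics.rank_0_only`, which is why W&B can be turned on simply from the toml file or via the new `--metrics.enable_wandb` flag.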