diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index 2321627e..69e9d89d 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -1,8 +1,9 @@ torchdata >= 0.8.0 datasets >= 2.21.0 tomli >= 1.1.0 ; python_version < "3.11" -tensorboard +aim sentencepiece tiktoken blobfile tabulate +transformers \ No newline at end of file diff --git a/README.md b/README.md index c425e65e..9271b1d3 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes 3. Selective layer and operator activation checkpointing 4. Distributed checkpointing 5. 2 datasets pre-configured (45K - 144M) -6. GPU usage, MFU, tokens per second and more displayed via TensorBoard +6. GPU usage, MFU, tokens per second and more displayed via Aim 6. Learning rate scheduler, meta init, Optional Fused RMSNorm 7. All options easily configured via [toml files](train_configs/) 8. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine tuning @@ -87,25 +87,6 @@ CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh ``` -## TensorBoard - -To visualize TensorBoard metrics of models trained on a remote server via a local web browser: - -1. Make sure `metrics.enable_tensorboard` option is set to true in model training (either from a .toml file or from CLI). - -2. Set up SSH tunneling, by running the following from local CLI -``` -ssh -L 6006:127.0.0.1:6006 [username]@[hostname] -``` - -3. Inside the SSH tunnel that logged into the remote server, go to the torchtitan repo, and start the TensorBoard backend -``` -tensorboard --logdir=./outputs/tb -``` - -4. In the local web browser, go to the URL it provides OR to http://localhost:6006/. - - ## Multi-Node Training For training on ParallelCluster/Slurm type configurations, you can use the `multinode_trainer.slurm` file to submit your sbatch job. 
diff --git a/pyproject.toml b/pyproject.toml index a5c1b72f..816952b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dev = [ "pre-commit", "pytest", "pytest-cov", - "tensorboard", + "aim", ] [tool.setuptools.dynamic] diff --git a/test/test_job_config.py b/test/test_job_config.py index e4ef04ba..2250eacd 100644 --- a/test/test_job_config.py +++ b/test/test_job_config.py @@ -14,12 +14,12 @@ class TestJobConfig: def test_command_line_args(self): config = JobConfig() config.parse_args([]) assert config.training.steps == 10000 def test_job_config_file(self): config = JobConfig() config.parse_args(["--job.config_file", "./train_configs/debug_model.toml"]) - assert config.training.steps == 10 + assert config.training.steps == 30 def test_job_file_does_not_exist(self): with pytest.raises(FileNotFoundError): @@ -30,7 +30,7 @@ def test_empty_config_file(self): with tempfile.NamedTemporaryFile() as fp: config = JobConfig() config.parse_args(["--job.config_file", fp.name]) assert config.job.description def test_job_config_file_cmd_overrides(self): config = JobConfig() @@ -42,7 +42,7 @@ def test_job_config_file_cmd_overrides(self): "/tmp/test_tt/", ] ) assert config.job.dump_folder == "/tmp/test_tt/" def test_print_help(self): config = JobConfig() diff --git a/torchtitan/aim.py b/torchtitan/aim.py new file mode 100644 index 00000000..8da24fbf --- /dev/null +++ b/torchtitan/aim.py @@ -0,0 +1,106 @@ +import os +from typing import Any, Dict, Optional + +from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT +from aim.sdk.repo import Repo +from aim.sdk.run import Run +from aim.sdk.utils import clean_repo_path, get_aim_repo_name + + +class AimLogger(): + def __init__( + self, + repo: Optional[str] = None, + experiment: Optional[str] = None, + system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT, + log_system_params: Optional[bool] = True, + capture_terminal_logs: Optional[bool] = True, + run_name: Optional[str] = None, + run_hash: Optional[str] = None, + 
train_metric_prefix: Optional[str] = 'train_', + val_metric_prefix: Optional[str] = 'val_', + test_metric_prefix: Optional[str] = 'test_', + ): + super().__init__() + + self._experiment_name = experiment + self._run_name = run_name + self._repo_path = repo + + self._system_tracking_interval = system_tracking_interval + self._log_system_params = log_system_params + self._capture_terminal_logs = capture_terminal_logs + + self._run = None + self._run_hash = run_hash + + self._train_metric_prefix = train_metric_prefix + self._val_metric_prefix = val_metric_prefix + self._test_metric_prefix = test_metric_prefix + + @property + def experiment(self) -> Run: + if self._run is None: + if self._run_hash: + self._run = Run( + self._run_hash, + repo=self._repo_path, + system_tracking_interval=self._system_tracking_interval, + capture_terminal_logs=self._capture_terminal_logs, + force_resume=True, + ) + else: + self._run = Run( + repo=self._repo_path, + experiment=self._experiment_name, + system_tracking_interval=self._system_tracking_interval, + log_system_params=self._log_system_params, + capture_terminal_logs=self._capture_terminal_logs, + ) + self._run_hash = self._run.hash + if self._run_name is not None: + self._run.name = self._run_name + return self._run + + def log_hyperparams(self, params: Dict[str, Any]): + for key, value in params.items(): + self.experiment.set(('hparams', key), value, strict=False) + + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): + + metric_items: Dict[str, Any] = {k: v for k, v in metrics.items()} # for modifications to metric_items without affecting the original metrics + for k, v in metric_items.items(): + name = k + context = {} + if self._train_metric_prefix and name.startswith(self._train_metric_prefix): + name = name[len(self._train_metric_prefix) :] + context['subset'] = 'train' + elif self._test_metric_prefix and name.startswith(self._test_metric_prefix): + name = name[len(self._test_metric_prefix) :] + 
context['subset'] = 'test' + elif self._val_metric_prefix and name.startswith(self._val_metric_prefix): + name = name[len(self._val_metric_prefix) :] + context['subset'] = 'val' + self.experiment.track(v, name=name, step=step, context=context) + + def finalize(self, status: str = '') -> None: + if self._run: + self._run.close() + del self._run + self._run = None + + def __del__(self): + self.finalize() + + @property + def save_dir(self) -> str: + repo_path = clean_repo_path(self._repo_path) or Repo.default_repo_path() + return os.path.join(repo_path, get_aim_repo_name()) + + @property + def name(self) -> str: + return self._experiment_name + + @property + def version(self) -> str: + return self.experiment.hash \ No newline at end of file diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py index b2bcfa17..ee326733 100644 --- a/torchtitan/checkpoint.py +++ b/torchtitan/checkpoint.py @@ -48,6 +48,8 @@ class TrainState(Stateful): step: int = 0 global_avg_losses: List[float] = field(default_factory=list) global_max_losses: List[float] = field(default_factory=list) + global_avg_perplexities: List[float] = field(default_factory=list) + global_max_perplexities: List[float] = field(default_factory=list) log_steps: List[int] = field(default_factory=list) def state_dict(self) -> Dict[str, Any]: diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 510005f3..e3d6f171 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -18,6 +18,8 @@ from torchtitan.logging import logger +from typing import Optional + TORCH_DTYPE_MAP = { "float16": torch.float16, "float32": torch.float32, @@ -118,7 +120,7 @@ def __init__(self): "--metrics.log_freq", type=int, default=10, - help="How often to log metrics to TensorBoard, in iterations", + help="How often to log metrics to aim, in iterations", ) self.parser.add_argument( "--metrics.enable_color_printing", @@ -127,22 +129,22 @@ def __init__(self): help="Whether to enable color 
printing", ) self.parser.add_argument( - "--metrics.enable_tensorboard", + "--metrics.enable_aim", action="store_true", - help="Whether to log metrics to TensorBoard", + help="Whether to log metrics to aim", ) self.parser.add_argument( - "--metrics.save_tb_folder", + "--metrics.save_aim_folder", type=str, - default="tb", - help="Folder to dump TensorBoard states", + default="aim", + help="Folder to dump Aim states", ) self.parser.add_argument( "--metrics.rank_0_only", default=True, action="store_true", help=""" - Whether to save TensorBoard metrics only for rank 0 or for all ranks. + Whether to save Aim metrics only for rank 0 or for all ranks. When pipeline_parallel_degree is > 1, this option uses the 0th rank of the last stage pipeline group, which is the only stage that computes loss metrics. """, @@ -546,7 +548,21 @@ def __init__(self): action="store_true", ) + self.parser.add_argument( + "--metrics.aim_hash", + type=str, + default=None, + help="The hash of the aim run to continue with", + ) + + self.parser.add_argument( + "--metrics.aim_experiment_name", + type=str, + default=None, + ) def parse_args(self, args_list: list = sys.argv[1:]): + self.args_list = args_list + args, cmd_args = self.parse_args_from_command_line(args_list) config_file = getattr(args, "job.config_file", None) # build up a two level dict diff --git a/torchtitan/metrics.py b/torchtitan/metrics.py index 3742115b..d3a61842 100644 --- a/torchtitan/metrics.py +++ b/torchtitan/metrics.py @@ -10,10 +10,10 @@ from typing import Any, Dict, Optional import torch -from torch.utils.tensorboard import SummaryWriter from torchtitan.config_manager import JobConfig from torchtitan.logging import logger from torchtitan.parallelisms import ParallelDims +from torchtitan.aim import AimLogger # named tuple for passing GPU memory stats for logging GPUMemStats = namedtuple( @@ -94,21 +94,27 @@ class MetricLogger: - def __init__(self, log_dir, tag, enable_tb): 
- self.tag = tag - self.writer: Optional[SummaryWriter] = None - if enable_tb: - self.writer = SummaryWriter(log_dir, max_queue=1000) + def __init__(self, hash, experiment_name, log_dir, save_aim_folder, enable_aim): + self.writer: Optional[AimLogger] = None + if enable_aim: + if hash is not None: + self.writer = AimLogger(save_aim_folder, run_hash=hash) + elif experiment_name is not None: + self.writer = AimLogger(save_aim_folder, experiment=experiment_name) + else: + self.writer = AimLogger(save_aim_folder) def log(self, metrics: Dict[str, Any], step: int): if self.writer is not None: - for k, v in metrics.items(): - tag = k if self.tag is None else f"{self.tag}/{k}" - self.writer.add_scalar(tag, v, step) + self.writer.log_metrics(metrics, step) def close(self): if self.writer is not None: - self.writer.close() + self.writer.finalize() + + def log_hparams(self, config): + if self.writer is not None: + self.writer.experiment['hparams'] = config def _get_metrics_rank(parallel_dims: ParallelDims) -> int: @@ -122,30 +128,26 @@ def _get_metrics_rank(parallel_dims: ParallelDims) -> int: def build_metric_logger( - job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None + job_config: JobConfig, parallel_dims: ParallelDims ): """ - parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'. + parallel_dims is used to determine the rank to log metrics from if 'aim_config.rank_0_only=True'. In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline parallelism is enabled, without forcing logging from all ranks to capture loss information. 
""" dump_dir = job_config.job.dump_folder - tb_config = job_config.metrics - save_tb_folder = tb_config.save_tb_folder + aim_config = job_config.metrics + save_aim_folder = aim_config.save_aim_folder # since we don't have run id, use current minute as the identifier datetime_str = datetime.now().strftime("%Y%m%d-%H%M") - log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str) + log_dir = os.path.join(dump_dir, datetime_str) - enable_tb = tb_config.enable_tensorboard - if enable_tb: + enable_aim = aim_config.enable_aim + if enable_aim: logger.info( - f"Metrics logging active. Tensorboard logs will be saved at {log_dir}" + f"Metrics logging active. Aim logs will be saved at /{save_aim_folder}" ) - if tb_config.rank_0_only: - enable_tb = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims) - else: - rank_str = f"rank_{torch.distributed.get_rank()}" - log_dir = os.path.join(log_dir, rank_str) + enable_aim = torch.distributed.get_rank() == _get_metrics_rank(parallel_dims) + return MetricLogger(job_config.metrics.aim_hash, job_config.metrics.aim_experiment_name, log_dir, save_aim_folder, enable_aim) - return MetricLogger(log_dir, tag, enable_tb) diff --git a/train.py b/train.py index 84a90ccc..73661a2c 100644 --- a/train.py +++ b/train.py @@ -201,8 +201,11 @@ def loss_fn(pred, labels): checkpoint_loaded = checkpoint.load() metric_logger = build_metric_logger(job_config, parallel_dims) + args, cmd_args = job_config.parse_args_from_command_line(job_config.args_list) + job_config_dict = job_config._args_to_two_level_dict(args) + metric_logger.log_hparams(job_config_dict) - # plot losses loaded from checkpoint (if any) to TensorBoard + # plot losses loaded from checkpoint (if any) to Aim # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. 
# This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq if train_state.step > 0: @@ -291,19 +294,31 @@ def loss_fn(pred, labels): or train_state.step % job_config.metrics.log_freq == 0 ): losses = [loss.item() for loss in losses_since_last_log] + + perplexities = [torch.exp(loss).item() for loss in losses_since_last_log] + avg_loss, max_loss = sum(losses) / len(losses), max(losses) + avg_perplexity, max_perplexity = sum(perplexities) / len(perplexities), max(perplexities) + if parallel_dims.dp_enabled: global_avg_loss, global_max_loss = ( utils.dist_mean(avg_loss, dp_mesh), utils.dist_max(max_loss, dp_mesh), ) + global_avg_perplexity, global_max_perplexity = ( + utils.dist_mean(avg_perplexity, dp_mesh), + utils.dist_max(max_perplexity, dp_mesh), + ) else: global_avg_loss, global_max_loss = avg_loss, max_loss + global_avg_perplexity, global_max_perplexity = avg_perplexity, max_perplexity # update train state train_state.log_steps.append(train_state.step) train_state.global_avg_losses.append(global_avg_loss) train_state.global_max_losses.append(global_max_loss) + train_state.global_avg_perplexities.append(global_avg_perplexity) + train_state.global_max_perplexities.append(global_max_perplexity) time_delta = time.perf_counter() - time_last_log @@ -325,6 +340,8 @@ def loss_fn(pred, labels): metrics = { "loss_metrics/global_avg_loss": global_avg_loss, "loss_metrics/global_max_loss": global_max_loss, + "loss_metrics/global_avg_perplexity": global_avg_perplexity, + "loss_metrics/global_max_perplexity": global_max_perplexity, "wps": wps, "mfu(%)": mfu, "time_metrics/end_to_end(s)": time_end_to_end, diff --git a/train_configs/chemlactica_125m.toml b/train_configs/chemlactica_125m.toml index 6c8b73f7..16136611 100644 --- a/train_configs/chemlactica_125m.toml +++ b/train_configs/chemlactica_125m.toml @@ -15,8 +15,10 @@ save_memory_snapshot_folder = "memory_snapshot" [metrics] log_freq = 1 enable_color_printing = true -enable_tensorboard = true 
-save_tb_folder = "tb" +enable_aim = true +save_aim_folder = "aim" +#aim_hash = "c6b4d8b340f74287b82ef928" +#aim_experiment_name = "hello" [model] name = "opt" diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 9cc9b52f..7db5da11 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -15,8 +15,10 @@ save_memory_snapshot_folder = "memory_snapshot" [metrics] log_freq = 1 enable_color_printing = true -enable_tensorboard = true -save_tb_folder = "tb" +# enable_aim = true +# save_aim_folder = "aim" +#aim_hash = "1d56ec7bed87438684a8da6b" +#aim_experiment_name = "hello" [model] name = "llama3" @@ -36,7 +38,7 @@ gradient_accumulation_steps = 1 seq_len = 2048 warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps max_norm = 1.0 # grad norm clipping -steps = 10 +steps = 30 data_parallel_degree = -1 tensor_parallel_degree = 1 compile = false @@ -47,8 +49,10 @@ pipeline_parallel_degree = 1 enable_async_tensor_parallel = false [checkpoint] -enable_checkpoint = false -folder = "checkpoint" +enable_checkpoint = true +save_folder = "checkpoint" +load_folder = "checkpoint" +create_seed_checkpoint = false interval_type = "steps" interval = 5 model_weights_only = false diff --git a/train_configs/galactica_125m_hf_to_titan.toml b/train_configs/galactica_125m_hf_to_titan.toml index 1318d4cf..f6e0a314 100644 --- a/train_configs/galactica_125m_hf_to_titan.toml +++ b/train_configs/galactica_125m_hf_to_titan.toml @@ -15,8 +15,10 @@ save_memory_snapshot_folder = "memory_snapshot" [metrics] log_freq = 1 enable_color_printing = true -enable_tensorboard = true -save_tb_folder = "tb" +enable_aim = true +save_aim_folder = "aim" +#aim_hash = "c6b4d8b340f74287b82ef928" +#aim_experiment_name = "hello" [model] name = "opt" diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml index df2f6bb3..c1f83e10 100644 --- a/train_configs/llama2_13b.toml +++ b/train_configs/llama2_13b.toml @@ -12,8 +12,8 @@ 
profile_freq = 100 [metrics] log_freq = 10 -enable_tensorboard = true -save_tb_folder = "tb" +enable_aim = true +save_aim_folder = "aim" [model] name = "llama2" diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml index 354ebe11..f6967d8c 100644 --- a/train_configs/llama2_70b.toml +++ b/train_configs/llama2_70b.toml @@ -12,8 +12,8 @@ profile_freq = 100 [metrics] log_freq = 10 -enable_tensorboard = true -save_tb_folder = "tb" +enable_aim = true +save_aim_folder = "aim" [model] name = "llama2" diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml index e2b0e78d..8ff10d34 100644 --- a/train_configs/llama2_7b.toml +++ b/train_configs/llama2_7b.toml @@ -11,8 +11,8 @@ profile_freq = 100 [metrics] log_freq = 10 -enable_tensorboard = true -save_tb_folder = "tb" +enable_aim = true +save_aim_folder = "aim" [model] name = "llama2" diff --git a/train_configs/llama3_405b.toml b/train_configs/llama3_405b.toml index 5dca66a5..11f81207 100644 --- a/train_configs/llama3_405b.toml +++ b/train_configs/llama3_405b.toml @@ -12,8 +12,8 @@ profile_freq = 100 [metrics] log_freq = 10 -enable_tensorboard = true -save_tb_folder = "tb" +enable_aim = true +save_aim_folder = "aim" [model] name = "llama3" diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index 470149a5..787f5a3a 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -12,8 +12,8 @@ profile_freq = 100 [metrics] log_freq = 10 -enable_tensorboard = true -save_tb_folder = "tb" +enable_aim = true +save_aim_folder = "aim" [model] name = "llama3" diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 3d0c5160..e32b4137 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -12,8 +12,8 @@ profile_freq = 100 [metrics] log_freq = 10 -enable_tensorboard = true -save_tb_folder = "tb" +enable_aim = true +save_aim_folder = "aim" [model] name = "llama3"