From 3ab9a9970881f77f222fa70f96bcaad34ac60b78 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Fri, 23 Feb 2024 12:39:43 -0800 Subject: [PATCH] make actual changes to recipes --- recipes/__init__.py | 6 +- .../configs/alpaca_llama2_full_finetune.yaml | 32 +- .../alpaca_llama2_full_finetune_hydra.yaml | 46 -- .../configs/alpaca_llama2_lora_finetune.yaml | 43 +- recipes/full_finetune.py | 139 +++--- recipes/full_finetune_hydra.py | 471 ------------------ recipes/lora_finetune.py | 164 +++--- recipes/params/full_finetune.py | 122 ----- recipes/params/lora_finetune.py | 141 ------ requirements.txt | 1 - tests/torchtune/datasets/test_get_dataset.py | 24 - tests/torchtune/models/test_get_model.py | 45 -- tests/torchtune/utils/test_argparse.py | 41 -- tests/torchtune/utils/test_metric_logging.py | 26 - torchtune/datasets/__init__.py | 22 +- torchtune/losses.py | 28 -- torchtune/models/__init__.py | 42 -- torchtune/modules/__init__.py | 82 --- torchtune/optim.py | 40 -- torchtune/utils/argparse.py | 72 --- torchtune/utils/metric_logging.py | 40 +- 21 files changed, 205 insertions(+), 1422 deletions(-) delete mode 100644 recipes/configs/alpaca_llama2_full_finetune_hydra.yaml delete mode 100644 recipes/full_finetune_hydra.py delete mode 100644 recipes/params/full_finetune.py delete mode 100644 recipes/params/lora_finetune.py delete mode 100644 tests/torchtune/datasets/test_get_dataset.py delete mode 100644 tests/torchtune/models/test_get_model.py delete mode 100644 tests/torchtune/utils/test_argparse.py delete mode 100644 torchtune/losses.py delete mode 100644 torchtune/optim.py delete mode 100644 torchtune/utils/argparse.py diff --git a/recipes/__init__.py b/recipes/__init__.py index f18679d85a..70c0a70060 100644 --- a/recipes/__init__.py +++ b/recipes/__init__.py @@ -5,7 +5,11 @@ # LICENSE file in the root directory of this source tree. -_RECIPE_LIST = ["full_finetune", "lora_finetune", "alpaca_generate", "full_finetune_hydra"] +_RECIPE_LIST = [ + "full_finetune", + "lora_finetune", + "alpaca_generate", +] _CONFIG_LISTS = { "full_finetune": ["alpaca_llama2_full_finetune"], "lora_finetune": ["alpaca_llama2_lora_finetune"], diff --git a/recipes/configs/alpaca_llama2_full_finetune.yaml b/recipes/configs/alpaca_llama2_full_finetune.yaml index 98fe30c7e9..fe7969951a 100644 --- a/recipes/configs/alpaca_llama2_full_finetune.yaml +++ b/recipes/configs/alpaca_llama2_full_finetune.yaml @@ -4,25 +4,43 @@ # tune --nnodes 1 --nproc_per_node 1 --config alpaca_llama2_full_finetune --override model_checkpoint= ... 
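Every entry below that carries a ``_target_`` key is now built with ``hydra.utils.instantiate`` inside the recipe, with the node's remaining keys forwarded as keyword arguments. A minimal sketch of how the dataset node resolves, assuming only that hydra-core (pinned in requirements.txt) and torchtune are importable; the tokenizer is itself built by the recipe from the ``tokenizer`` node at setup time:

    import hydra
    from omegaconf import OmegaConf

    tokenizer = hydra.utils.instantiate(
        OmegaConf.create(
            {"_target_": "torchtune.models.llama2_tokenizer", "path": "/tmp/tokenizer.model"}
        )
    )
    dataset_node = OmegaConf.create(
        {"_target_": "torchtune.datasets.AlpacaDataset", "train_on_input": True}
    )
    # Call-time arguments are layered on top of the config's keys, which is how the
    # recipe injects the tokenizer it just constructed.
    ds = hydra.utils.instantiate(dataset_node, tokenizer=tokenizer)
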
# Dataset and Dataloader -dataset: alpaca +dataset: + _target_: torchtune.datasets.AlpacaDataset + train_on_input: True + seed: null shuffle: True # Model Arguments -model: llama2_7b +model: + _target_: torchtune.models.llama2_7b + model_checkpoint: /tmp/llama2-7b -tokenizer: llama2_tokenizer -tokenizer_checkpoint: /tmp/tokenizer.model +tokenizer: + _target_: torchtune.models.llama2_tokenizer + path: /tmp/tokenizer.model # Fine-tuning arguments batch_size: 2 -lr: 2e-5 epochs: 3 -optimizer: SGD -loss: CrossEntropyLoss +optimizer: + _target_: torch.optim.SGD + lr: 2e-5 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +log_every_n_steps: null +run_generation: null + +loss: + _target_: torch.nn.CrossEntropyLoss + output_dir: /tmp/alpaca-llama2-finetune device: cuda dtype: fp32 enable_fsdp: True enable_activation_checkpointing: True +cpu_offload: False resume_from_checkpoint: False +metric_logger: + _target_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} diff --git a/recipes/configs/alpaca_llama2_full_finetune_hydra.yaml b/recipes/configs/alpaca_llama2_full_finetune_hydra.yaml deleted file mode 100644 index fe7969951a..0000000000 --- a/recipes/configs/alpaca_llama2_full_finetune_hydra.yaml +++ /dev/null @@ -1,46 +0,0 @@ -# Runs the full_finetune.py recipe using FullFinetuneParams -# -# To launch, run the following command from root: -# tune --nnodes 1 --nproc_per_node 1 --config alpaca_llama2_full_finetune --override model_checkpoint= ... - -# Dataset and Dataloader -dataset: - _target_: torchtune.datasets.AlpacaDataset - train_on_input: True - -seed: null -shuffle: True - -# Model Arguments -model: - _target_: torchtune.models.llama2_7b - -model_checkpoint: /tmp/llama2-7b -tokenizer: - _target_: torchtune.models.llama2_tokenizer - path: /tmp/tokenizer.model - -# Fine-tuning arguments -batch_size: 2 -epochs: 3 -optimizer: - _target_: torch.optim.SGD - lr: 2e-5 -max_steps_per_epoch: null -gradient_accumulation_steps: 1 -log_every_n_steps: null -run_generation: null - -loss: - _target_: torch.nn.CrossEntropyLoss - -output_dir: /tmp/alpaca-llama2-finetune -device: cuda -dtype: fp32 -enable_fsdp: True -enable_activation_checkpointing: True -cpu_offload: False -resume_from_checkpoint: False -metric_logger: - _target_: torchtune.utils.metric_logging.DiskLogger - log_dir: ${output_dir} diff --git a/recipes/configs/alpaca_llama2_lora_finetune.yaml b/recipes/configs/alpaca_llama2_lora_finetune.yaml index 6316b86ba4..c32d43edfd 100644 --- a/recipes/configs/alpaca_llama2_lora_finetune.yaml +++ b/recipes/configs/alpaca_llama2_lora_finetune.yaml @@ -1,9 +1,11 @@ # Model Arguments -model: lora_llama2_7b +model: + _target_: lora_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + lora_rank: 8 + lora_alpha: 16 + model_checkpoint: /tmp/llama2-7b -lora_attn_modules: ['q_proj', 'v_proj'] -lora_rank: 8 -lora_alpha: 16 lora_checkpoint: null # Tokenizer @@ -11,26 +13,37 @@ tokenizer: llama2_tokenizer tokenizer_checkpoint: /tmp/tokenizer.model # Dataset and Sampler -dataset: alpaca -train_on_input: True +dataset: + _target_: torchtune.datasets.AlpacaDataset + train_on_input: True + use_clean: True shuffle: True batch_size: 2 # Optimizer and Scheduler -optimizer: AdamW -weight_decay: 0.01 -lr: 3e-4 -lr_scheduler: cosine_with_warmup -num_warmup_steps: 100 -loss: CrossEntropyLoss +optimizer: + _target_: AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: # TODO: this is a partial instantiation, make this more elegant + _target_: torchtune.modules.get_cosine_schedule_with_warmup + 
num_warmup_steps: 100 + +loss: + _target_: torch.nn.CrossEntropyLoss # Training epochs: 1 resume_from_checkpoint: False +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _target_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} + # Environment device: cuda dtype: fp32 - -# Logging -output_dir: /tmp/lora_finetune_output +enable_fsdp: True +enable_activation_checkpointing: True diff --git a/recipes/full_finetune.py b/recipes/full_finetune.py index c04858383b..55bb75c77e 100644 --- a/recipes/full_finetune.py +++ b/recipes/full_finetune.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import argparse import os import sys @@ -12,10 +11,12 @@ from typing import Any, Dict, Optional, Tuple from warnings import warn +import hydra + import torch +from omegaconf import DictConfig from recipes.interfaces import FTRecipeInterface -from recipes.params.full_finetune import FullFinetuneParams from torch import nn from torch.cuda.amp import GradScaler @@ -23,7 +24,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import datasets, models, modules, utils +from torchtune import modules, utils from torchtune.utils.constants import ( EPOCHS_KEY, MAX_STEPS_KEY, @@ -57,40 +58,63 @@ class FullFinetuneRecipe(FTRecipeInterface): - Training happens on CUDA (CPU training is not supported) - Checkpoints are ONLY saved at epoch boundaries. Mid-epoch checkpointing is NOT supported. - Datasets are Map-style and data fits in memory (not streamed). + + Args: + device (str): Device to use for training. Options are "cpu" and "cuda" + dtype (str): Data type to use for training. + seed (int): Random seed to use for training. + model (str): String specifying model architecture to fine-tune. See ``torchtune.models.get_model`` for options. + model_checkpoint (str): Local path to load model checkpoint from. + tokenizer (str): String specifying tokenizer to use. See ``torchtune.models.get_tokenizer`` for options. + tokenizer_checkpoint (str): Local path to load tokenizer checkpoint from. + dataset (str): String specifying dataset to use. See ``torchtune.datasets.get_dataset`` for options. + Currently, only predefined datasets in library are supported. + shuffle (bool): Whether to shuffle dataset. + batch_size (int): Batch size to use for training. + epochs (int): Number of epochs to train for. + optimizer (str): String specifying optimizer to use. See ``torchtune.optim.get_optimizer`` for options. + loss (str): String specifying loss function to use. See ``torchtune.losses.get_loss`` for options. + lr (float): Learning rate to use for optimizer. + activation_checkpointing (bool): Whether to use activation checkpointing. + output_dir (str): Local path to save checkpoints and logs to. + run_generation (int): Run eval on a prompt every ``run_generation`` steps. Set to 0 to disable. + max_steps_per_epoch (int): Maximum number of steps to take per epoch. + metric_logger_type (str): String specifying metric logger to use. See ``torchtune.utils.get_metric_logger`` + for options. + project (str): Project name to use for logging. Used by ``WandBLogger``. + resume_from_previous_checkpoint (bool): Whether to resume fine-tuning from a previous checkpoint. + cpu_offload (bool): Whether to offload model to CPU. + + Raises: + ValueError: If ``cpu_offload`` is ``True`` but ``device`` is not ``cuda`` and <= 1 GPUs. 
""" - def __init__(self, params: FullFinetuneParams) -> None: + def __init__(self, cfg: DictConfig) -> None: - self._device = utils.get_device(device=params.device) - self._dtype = utils.get_dtype(dtype=params.dtype) + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(dtype=cfg.dtype) # logging attributes - self._output_dir = params.output_dir - self._metric_logger = utils.get_metric_logger( - metric_logger_type=params.metric_logger_type, - project=params.project, - log_dir=params.output_dir, - ) - self._log_every_n_steps = ( - params.log_every_n_steps if params.log_every_n_steps else 1 - ) + self._output_dir = cfg.output_dir + self._metric_logger = hydra.utils.instantiate(cfg.metric_logger) + self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1 # _is_rank_zero is used primarily for logging. In the future, the logger # should directly take care of this _, rank = utils.get_world_size_and_rank() self._is_rank_zero = rank == 0 - # Training params - self._resume_from_checkpoint = params.resume_from_checkpoint - self._enable_fsdp = params.enable_fsdp - self._gradient_accumulation_steps = params.gradient_accumulation_steps + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._enable_fsdp = cfg.enable_fsdp + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = utils.set_seed(seed=params.seed) + self.seed = utils.set_seed(seed=cfg.seed) self.epochs_run = 0 - self.total_epochs = params.epochs - self.max_steps_per_epoch = params.max_steps_per_epoch + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch self.total_training_steps = 0 def load_checkpoint(self, ckpt_path: str): @@ -101,13 +125,13 @@ def load_checkpoint(self, ckpt_path: str): utils.validate_checkpoint(ckpt_dict, self._resume_from_checkpoint) return ckpt_dict - def setup(self, params: FullFinetuneParams) -> None: + def setup(self, cfg: DictConfig) -> None: """ Sets up the recipe state correctly. This includes setting recipe attributes based on the ``resume_from_checkpoint`` flag. """ - ckpt_dict = self.load_checkpoint(ckpt_path=params.model_checkpoint) + ckpt_dict = self.load_checkpoint(ckpt_path=cfg.model_checkpoint) # If we're resuming from checkpoint, the recipe's state should be updated before # initializing the training components. This ensures that the seed is correctly @@ -119,40 +143,38 @@ def setup(self, params: FullFinetuneParams) -> None: # should be called before ``_setup_optimizer`` since transforming the optimizer # state dict requires the model self._model = self._setup_model( - model=params.model, - enable_fsdp=params.enable_fsdp, - enable_activation_checkpointing=params.enable_activation_checkpointing, + cfg_model=cfg.model, + enable_fsdp=cfg.enable_fsdp, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, model_state_dict=ckpt_dict[MODEL_KEY], ) self._tokenizer = self._setup_tokenizer( - tokenizer=params.tokenizer, tokenizer_checkpoint=params.tokenizer_checkpoint + cfg_tokenizer=cfg.tokenizer, ) # _setup_optimizer should take in ckpt_dict only if training is resumed from # checkpoint. 
Transforming the opt state dict is handled by this method self._optimizer = self._setup_optimizer( - optimizer=params.optimizer, - lr=params.lr, + cfg_optimizer=cfg.optimizer, opt_state_dict=ckpt_dict[OPT_KEY] if self._resume_from_checkpoint else None, ) - self._loss_fn = self._setup_loss(loss=params.loss) + self._loss_fn = self._setup_loss(cfg_loss=cfg.loss) # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized self._sampler, self._dataloader = self._setup_data( - dataset=params.dataset, - train_on_input=params.train_on_input, - shuffle=params.shuffle, - batch_size=params.batch_size, + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, ) # training setup self._autocast = utils.get_autocast(self._dtype, self._device) self._grad_scaler = None if self._dtype == torch.float16: - self._grad_scaler = utils.get_gradient_scaler(fsdp=params.enable_fsdp) + self._grad_scaler = utils.get_gradient_scaler(fsdp=cfg.enable_fsdp) else: self._grad_scaler = GradScaler(enabled=False) @@ -195,7 +217,7 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: def _setup_model( self, - model: str, + cfg_model: DictConfig, enable_fsdp: bool, enable_activation_checkpointing: bool, model_state_dict: Dict[str, Any], @@ -205,7 +227,9 @@ def _setup_model( ``enable_fsdp`` should always be ``True``. This is currently a configurable flag for running tests on CPUs. """ - model = models.get_model(model, device=self._device) + with self._device: + model = hydra.utils.instantiate(cfg_model) + model = ( utils.wrap_fsdp( model=model, @@ -229,27 +253,28 @@ def _setup_model( return model def _setup_tokenizer( - self, tokenizer: str, tokenizer_checkpoint: str + self, + cfg_tokenizer: DictConfig, ) -> modules.Tokenizer: """ Unlike ```setup_model```, this takes in the checkpoint and loads the sentencepiece tokenizer model. This is related to how the tokenizer is implemented and should change in a future iteration. """ - tokenizer = models.get_tokenizer(tokenizer, path=tokenizer_checkpoint) + tokenizer = hydra.utils.instantiate(cfg_tokenizer) if self._is_rank_zero: log.info("Tokenizer is initialized from file.") return tokenizer def _setup_optimizer( - self, optimizer: str, lr: float, opt_state_dict: Optional[Dict[str, Any]] = None + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None ) -> Optimizer: """ Set up the optimizer. This method also handles transforing the state dict for FSDP. """ - optimizer = modules.get_optimizer(optimizer, self._model, lr) + optimizer = hydra.utils.instantiate(cfg_optimizer, self._model.parameters()) if opt_state_dict: opt_state_dict = utils.transform_opt_state_dict( opt_state_dict, self._model, optimizer @@ -260,8 +285,8 @@ def _setup_optimizer( log.info("Optimizer is initialized.") return optimizer - def _setup_loss(self, loss: str) -> nn.Module: - loss_fn = modules.get_loss(loss) + def _setup_loss(self, cfg_loss: DictConfig) -> nn.Module: + loss_fn = hydra.utils.instantiate(cfg_loss) if self._is_rank_zero: log.info("Loss is initialized.") @@ -269,7 +294,10 @@ def _setup_loss(self, loss: str) -> nn.Module: return loss_fn def _setup_data( - self, dataset: str, shuffle: bool, batch_size: int, train_on_input: bool + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, ) -> Tuple[DistributedSampler, DataLoader]: """ All data related setup happens here. 
Currently this recipe only supports the @@ -277,11 +305,9 @@ def _setup_data( iterable datasets and streaming datasets are not supported. """ world_size, rank = utils.get_world_size_and_rank() - ds = datasets.get_dataset( - dataset, - split="train", + ds = hydra.utils.instantiate( + cfg_dataset, tokenizer=self._tokenizer, - train_on_input=train_on_input, ) sampler = DistributedSampler( ds, @@ -425,7 +451,8 @@ def cleanup(self) -> None: self._metric_logger.close() -def recipe_main() -> None: +@hydra.main(config_path="configs") +def recipe_main(cfg: DictConfig) -> None: """ Entry point for the recipe. @@ -434,19 +461,11 @@ def recipe_main() -> None: - Overwritten by Parameters specified in ``alpaca_llama2_full_finetune.yaml`` - Overwritten by arguments from the command-line using ``TuneArgumentParser`` """ - parser = utils.TuneArgumentParser( - description=FullFinetuneParams.__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - args, _ = parser.parse_known_args() - args = vars(args) - recipe_params = FullFinetuneParams(**args) - # Env variables set by torch run; only need to initialize process group init_process_group(backend="nccl") - recipe = FullFinetuneRecipe(params=recipe_params) - recipe.setup(params=recipe_params) + recipe = FullFinetuneRecipe(cfg=cfg) + recipe.setup(cfg=cfg) recipe.train() recipe.cleanup() diff --git a/recipes/full_finetune_hydra.py b/recipes/full_finetune_hydra.py deleted file mode 100644 index 94110f4263..0000000000 --- a/recipes/full_finetune_hydra.py +++ /dev/null @@ -1,471 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -import sys - -from functools import partial -from typing import Any, Dict, Optional, Tuple -from warnings import warn - -import torch - -from recipes.interfaces import FTRecipeInterface - -from torch import nn -from torch.cuda.amp import GradScaler -from torch.distributed import init_process_group -from torch.optim import Optimizer -from torch.utils.data import DataLoader, DistributedSampler - -from torchtune import modules, utils -from torchtune.utils.constants import ( - EPOCHS_KEY, - MAX_STEPS_KEY, - MODEL_KEY, - OPT_KEY, - SEED_KEY, - TOTAL_EPOCHS_KEY, -) - -from tqdm import tqdm -import hydra -from omegaconf import DictConfig - - -log = utils.get_logger("DEBUG") - - -class FullFinetuneRecipe(FTRecipeInterface): - """ - Full finetuning recipe for dense transformer-based LLMs such as Llama2. - - This recipe supports: - - FSDP and activation checkpointing. This is enabled by default but can be - configured using the ``enable_fsdp`` and ``enable_activation_checkpointing`` flags. - - Mixed precision training - fp32, fp16 and bf16 are supported. - - Checkpointing of model weights, optimizer state and the recipe state (epoch and seed). - - Resuming from checkpoints saved using the ``save_checkpoint`` functionality. - - Logging to terminal. WandB and TensorBoard are currently not supported. - - Assumptions: - - Training is launched with the Tune CLI (recommended) which uses TorchRun under the - hood. Setting up the env variables is handled by TorchRun. - - Training happens on CUDA (CPU training is not supported) - - Checkpoints are ONLY saved at epoch boundaries. Mid-epoch checkpointing is NOT supported. - - Datasets are Map-style and data fits in memory (not streamed). - - Args: - device (str): Device to use for training. 
Options are "cpu" and "cuda" - dtype (str): Data type to use for training. - seed (int): Random seed to use for training. - model (str): String specifying model architecture to fine-tune. See ``torchtune.models.get_model`` for options. - model_checkpoint (str): Local path to load model checkpoint from. - tokenizer (str): String specifying tokenizer to use. See ``torchtune.models.get_tokenizer`` for options. - tokenizer_checkpoint (str): Local path to load tokenizer checkpoint from. - dataset (str): String specifying dataset to use. See ``torchtune.datasets.get_dataset`` for options. - Currently, only predefined datasets in library are supported. - shuffle (bool): Whether to shuffle dataset. - batch_size (int): Batch size to use for training. - epochs (int): Number of epochs to train for. - optimizer (str): String specifying optimizer to use. See ``torchtune.optim.get_optimizer`` for options. - loss (str): String specifying loss function to use. See ``torchtune.losses.get_loss`` for options. - lr (float): Learning rate to use for optimizer. - activation_checkpointing (bool): Whether to use activation checkpointing. - output_dir (str): Local path to save checkpoints and logs to. - run_generation (int): Run eval on a prompt every ``run_generation`` steps. Set to 0 to disable. - max_steps_per_epoch (int): Maximum number of steps to take per epoch. - metric_logger_type (str): String specifying metric logger to use. See ``torchtune.utils.get_metric_logger`` - for options. - project (str): Project name to use for logging. Used by ``WandBLogger``. - resume_from_previous_checkpoint (bool): Whether to resume fine-tuning from a previous checkpoint. - cpu_offload (bool): Whether to offload model to CPU. - - Raises: - ValueError: If ``cpu_offload`` is ``True`` but ``device`` is not ``cuda`` and <= 1 GPUs. - """ - - def __init__(self, cfg: DictConfig) -> None: - - self._device = utils.get_device(device=cfg.device) - self._dtype = utils.get_dtype(dtype=cfg.dtype) - - # logging attributes - self._output_dir = cfg.output_dir - self._metric_logger = hydra.utils.instantiate(cfg.metric_logger) - self._log_every_n_steps = ( - cfg.log_every_n_steps if cfg.log_every_n_steps else 1 - ) - - # _is_rank_zero is used primarily for logging. In the future, the logger - # should directly take care of this - _, rank = utils.get_world_size_and_rank() - self._is_rank_zero = rank == 0 - - # Training cfg - self._resume_from_checkpoint = cfg.resume_from_checkpoint - self._enable_fsdp = cfg.enable_fsdp - self._gradient_accumulation_steps = cfg.gradient_accumulation_steps - - # These are public properties which are updated by the checkpoint loader - # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = utils.set_seed(seed=cfg.seed) - self.epochs_run = 0 - self.total_epochs = cfg.epochs - self.max_steps_per_epoch = cfg.max_steps_per_epoch - self.total_training_steps = 0 - - def load_checkpoint(self, ckpt_path: str): - """ - Extract the checkpoint state from file and validate. - """ - ckpt_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) - utils.validate_checkpoint(ckpt_dict, self._resume_from_checkpoint) - return ckpt_dict - - def setup(self, cfg: DictConfig) -> None: - """ - Sets up the recipe state correctly. This includes setting recipe attributes based - on the ``resume_from_checkpoint`` flag. 
- """ - - ckpt_dict = self.load_checkpoint(ckpt_path=cfg.model_checkpoint) - - # If we're resuming from checkpoint, the recipe's state should be updated before - # initializing the training components. This ensures that the seed is correctly - # propagated to the relevant components - if self._resume_from_checkpoint: - self._update_recipe_state(ckpt_dict) - - # ``_setup_model`` handles initialization and loading the state dict. This method - # should be called before ``_setup_optimizer`` since transforming the optimizer - # state dict requires the model - self._model = self._setup_model( - cfg_model=cfg.model, - enable_fsdp=cfg.enable_fsdp, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, - model_state_dict=ckpt_dict[MODEL_KEY], - ) - - self._tokenizer = self._setup_tokenizer( - cfg_tokenizer=cfg.tokenizer, - ) - - # _setup_optimizer should take in ckpt_dict only if training is resumed from - # checkpoint. Transforming the opt state dict is handled by this method - self._optimizer = self._setup_optimizer( - cfg_optimizer=cfg.optimizer, - opt_state_dict=ckpt_dict[OPT_KEY] if self._resume_from_checkpoint else None, - ) - - self._loss_fn = self._setup_loss(cfg_loss=cfg.loss) - - # sampler and dataloader depend on the tokenizer and loss_fn and should be - # setup after both of these are initialized - self._sampler, self._dataloader = self._setup_data( - cfg_dataset=cfg.dataset, - shuffle=cfg.shuffle, - batch_size=cfg.batch_size, - ) - - # training setup - self._autocast = utils.get_autocast(self._dtype, self._device) - self._grad_scaler = None - if self._dtype == torch.float16: - self._grad_scaler = utils.get_gradient_scaler(fsdp=cfg.enable_fsdp) - else: - self._grad_scaler = GradScaler(enabled=False) - - # Finally update the recipe state which can only be correctly set after all of the - # other components have been initialized and updated. - # - # Number of training steps in each epoch depends on the number of batches produced - # by the dataloader, the max_steps_per_epoch param set by the user and the - # gradient_accumulation_steps param. This value is used for logging and tracking - # training state. The computation should happen after the dataloader has been setup - self._steps_per_epoch = ( - len(self._dataloader) // self._gradient_accumulation_steps - ) - if ( - self.max_steps_per_epoch is not None - and self.max_steps_per_epoch < self._steps_per_epoch - ): - self._steps_per_epoch = self.max_steps_per_epoch - self.total_training_steps = self.epochs_run * self._steps_per_epoch - - def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: - """ - Updates the recipe state from checkpoint. - """ - # If seed, total_epoch or max_steps_per_epoch don't match, - # warn the user and overwrite - if ( - self.seed != ckpt_dict[SEED_KEY] - or self.total_epochs != ckpt_dict[TOTAL_EPOCHS_KEY] - or self.max_steps_per_epoch != ckpt_dict[MAX_STEPS_KEY] - ): - warn( - message="""Configured value for seed, epochs or max_steps_per_epoch - does not match the value stored in checkpoint.""" - ) - self.seed = utils.set_seed(seed=ckpt_dict[SEED_KEY]) - self.epochs_run = ckpt_dict[EPOCHS_KEY] - self.total_epochs = ckpt_dict[TOTAL_EPOCHS_KEY] - self.max_steps_per_epoch = ckpt_dict[MAX_STEPS_KEY] - - def _setup_model( - self, - cfg_model: DictConfig, - enable_fsdp: bool, - enable_activation_checkpointing: bool, - model_state_dict: Dict[str, Any], - ) -> nn.Module: - """ - Set up the model including enabling FSDP and activation checkpointing. 
For this recipe, - ``enable_fsdp`` should always be ``True``. This is currently a configurable flag for - running tests on CPUs. - """ - with self._device: - model = hydra.utils.instantiate(cfg_model) - - model = ( - utils.wrap_fsdp( - model=model, - device=self._device, - dtype=self._dtype, - strategy="FULL_SHARD", - auto_wrap_policy={modules.TransformerDecoderLayer}, - ) - if enable_fsdp - else model - ) - if enable_activation_checkpointing: - utils.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerDecoderLayer} - ) - - model.load_state_dict(model_state_dict) - - if self._is_rank_zero: - log.info("Model is initialized.") - return model - - def _setup_tokenizer( - self, cfg_tokenizer: DictConfig, - ) -> modules.Tokenizer: - """ - Unlike ```setup_model```, this takes in the checkpoint and loads the sentencepiece - tokenizer model. This is related to how the tokenizer is implemented and should - change in a future iteration. - """ - tokenizer = hydra.utils.instantiate(cfg_tokenizer) - - if self._is_rank_zero: - log.info("Tokenizer is initialized from file.") - return tokenizer - - def _setup_optimizer( - self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None - ) -> Optimizer: - """ - Set up the optimizer. This method also handles transforing the state dict - for FSDP. - """ - optimizer = hydra.utils.instantiate(cfg_optimizer, self._model.parameters()) - if opt_state_dict: - opt_state_dict = utils.transform_opt_state_dict( - opt_state_dict, self._model, optimizer - ) - optimizer.load_state_dict(opt_state_dict) - - if self._is_rank_zero: - log.info("Optimizer is initialized.") - return optimizer - - def _setup_loss(self, cfg_loss: DictConfig) -> nn.Module: - loss_fn = hydra.utils.instantiate(cfg_loss) - - if self._is_rank_zero: - log.info("Loss is initialized.") - - return loss_fn - - def _setup_data( - self, cfg_dataset: DictConfig, shuffle: bool, batch_size: int, - ) -> Tuple[DistributedSampler, DataLoader]: - """ - All data related setup happens here. Currently this recipe only supports the - DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, - iterable datasets and streaming datasets are not supported. - """ - world_size, rank = utils.get_world_size_and_rank() - ds = hydra.utils.instantiate( - cfg_dataset, - tokenizer=self._tokenizer, - ) - sampler = DistributedSampler( - ds, - num_replicas=world_size, - rank=rank, - shuffle=shuffle, - seed=0, - ) - dataloader = DataLoader( - dataset=ds, - batch_size=batch_size, - sampler=sampler, - collate_fn=partial( - utils.padded_collate, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, # TODO support loss without ignore_index - ), - ) - - if self._is_rank_zero: - log.info("Dataset and Sampler are initialized.") - - return sampler, dataloader - - def save_checkpoint(self, epoch: int) -> None: - """ - Checkpoint the relevant state of a recipe. - - This makes use of the `save_checkpoint` utility which is responsible for - writing the checkpoint dictionary to file. The contents of the dict are dictated - by whether training is complete or not. - - If training is ongoing, optimizer state, seed and epochs_run are saved along with the - model weights. 
- """ - os.makedirs(self._output_dir, exist_ok=True) - output_loc = f"{self._output_dir}/model_{epoch}.ckpt" - ckpt_dict = {MODEL_KEY: self._model} - - # if training is in-progress, checkpoint the optimizer state as well - if epoch + 1 < self.total_epochs: - ckpt_dict.update( - { - OPT_KEY: self._optimizer, - SEED_KEY: self.seed, - EPOCHS_KEY: self.epochs_run, - TOTAL_EPOCHS_KEY: self.total_epochs, - MAX_STEPS_KEY: self.max_steps_per_epoch, - } - ) - utils.save_checkpoint(ckpt_dict, output_loc) - - if self._is_rank_zero: - log.info( - f"Model checkpoint of size {os.path.getsize(output_loc) >> 20} MB saved to {output_loc}" - ) - - def _should_update_weights(self, curr_step: int) -> bool: - """ - Determines whether the weights should be updated on the current step or not. - True is returned either if we've accumulated gradients for enough steps or if this - is the last step in the epoch. - """ - should_update_weights = ( - curr_step + 1 - ) % self._gradient_accumulation_steps == 0 or ( - curr_step + 1 - ) == self._steps_per_epoch - return should_update_weights - - def train(self) -> None: - """ - The core training loop. Supports training on subsets of the dataset using the - ``max_steps_per_epoch``. - """ - _, rank = utils.get_world_size_and_rank() - - # zero out the gradients before starting training - self._optimizer.zero_grad() - - # self.epochs_run should be non-zero when we're resuming from a checkpoint - for curr_epoch in range(self.epochs_run, self.total_epochs): - - # Update the sampler to ensure data is correctly shuffled across epochs - # in case shuffle is True - self._sampler.set_epoch(curr_epoch) - - for idx, batch in enumerate( - pbar := tqdm(self._dataloader, disable=not (rank == 0)) - ): - if ( - self.max_steps_per_epoch is not None - and (idx // self._gradient_accumulation_steps) - == self.max_steps_per_epoch - ): - break - - input_ids, labels = batch - input_ids = input_ids.to(self._device) - labels = labels.to(self._device) - - with self._autocast: - logits = self._model(input_ids) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) - # Compute loss - loss = self._loss_fn(logits, labels) - - # Note: We're always logging the loss before normalizing it - # Check if this is the norm or not - pbar.set_description(f"{curr_epoch+1}|{idx+1}|Loss: {loss.item()}") - - if self.total_training_steps % self._log_every_n_steps == 0: - self._metric_logger.log_dict( - { - "loss": loss.item(), - "lr": self._optimizer.param_groups[0]["lr"], - "gpu_resources": torch.cuda.memory_allocated(), - }, - step=self.total_training_steps, - ) - - # Does loss normalization need to happen within autocast context? - loss = loss / self._gradient_accumulation_steps - self._grad_scaler.scale(loss).backward() - - if self._should_update_weights(idx): - self._grad_scaler.step(self._optimizer) - self._grad_scaler.update() - self._optimizer.zero_grad(set_to_none=True) - - # Update the number of steps when the weights are updated - self.total_training_steps += 1 - - self.epochs_run += 1 - self.save_checkpoint(epoch=curr_epoch) - - def cleanup(self) -> None: - self._metric_logger.close() - - -@hydra.main(config_path="configs") -def recipe_main(cfg: DictConfig) -> None: - """ - Entry point for the recipe. 
- - Configurable parameters are read in the following order: - - Parameters specified in ``FullFinetuneParams`` - - Overwritten by Parameters specified in ``alpaca_llama2_full_finetune.yaml`` - - Overwritten by arguments from the command-line using ``TuneArgumentParser`` - """ - # Env variables set by torch run; only need to initialize process group - init_process_group(backend="nccl") - - recipe = FullFinetuneRecipe(cfg=cfg) - recipe.setup(cfg=cfg) - recipe.train() - recipe.cleanup() - - -if __name__ == "__main__": - sys.exit(recipe_main()) diff --git a/recipes/lora_finetune.py b/recipes/lora_finetune.py index c82b339d12..bc84794813 100644 --- a/recipes/lora_finetune.py +++ b/recipes/lora_finetune.py @@ -4,31 +4,32 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import argparse import os import sys from functools import partial -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Optional, Tuple from warnings import warn +import hydra + import torch +from omegaconf import DictConfig from recipes.interfaces import FTRecipeInterface -from recipes.params.lora_finetune import LoRAFinetuneParams from torch import nn from torch.cuda.amp import GradScaler from torch.distributed import init_process_group from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import datasets, models, modules, utils -from torchtune.modules.peft.lora import reset_lora_params +from torchtune import modules, utils +from torchtune.modules.peft.lora import reset_lora_cfg from torchtune.modules.peft.peft_utils import ( - get_adapter_params, + get_adapter_cfg, lora_fsdp_init, lora_fsdp_wrap_policy, - set_trainable_params, + set_trainable_cfg, validate_state_dict_for_lora, ) from torchtune.utils.constants import ( @@ -39,7 +40,7 @@ SEED_KEY, TOTAL_EPOCHS_KEY, ) -from torchtune.utils.distributed import validate_no_meta_params +from torchtune.utils.distributed import validate_no_meta_cfg from tqdm import tqdm log = utils.get_logger("DEBUG") @@ -65,10 +66,10 @@ class LoRAFinetuneRecipe(FTRecipeInterface): """ - def __init__(self, params: LoRAFinetuneParams) -> None: + def __init__(self, cfg: DictConfig) -> None: - self._device = utils.get_device(device=params.device) - self._dtype = utils.get_dtype(dtype=params.dtype) + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(dtype=cfg.dtype) # _is_rank_zero is used primarily for logging. 
In the future, the logger # should directly take care of this @@ -76,28 +77,22 @@ def __init__(self, params: LoRAFinetuneParams) -> None: self._is_rank_zero = rank == 0 # logging attributes - self._output_dir = params.output_dir - self._log_every_n_steps = ( - params.log_every_n_steps if params.log_every_n_steps else 1 - ) + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1 if self._is_rank_zero: - self._metric_logger = utils.get_metric_logger( - metric_logger_type=params.metric_logger_type, - project=params.project, - log_dir=params.output_dir, - ) + self._metric_logger = hydra.utils.instantiate(cfg.metric_logger) # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = utils.set_seed(seed=params.seed) + self.seed = utils.set_seed(seed=cfg.seed) self.epochs_run = 0 - self.total_epochs = params.epochs - self.max_steps_per_epoch = params.max_steps_per_epoch + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch self.total_training_steps = 0 - self._resume_from_checkpoint = params.resume_from_checkpoint + self._resume_from_checkpoint = cfg.resume_from_checkpoint - def setup(self, params: LoRAFinetuneParams) -> None: + def setup(self, cfg: DictConfig) -> None: """ Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. @@ -107,7 +102,7 @@ def setup(self, params: LoRAFinetuneParams) -> None: # This is because we only save LoRA weights during training, so only lora_checkpoint # will contain training state, while model_checkpoint contains model weights only. 
base_model_ckpt = self.load_checkpoint( - ckpt_path=params.model_checkpoint, resume_from_checkpoint=False + ckpt_path=cfg.model_checkpoint, resume_from_checkpoint=False ) # If we're resuming from checkpoint, the recipe's state should be updated before @@ -115,20 +110,17 @@ def setup(self, params: LoRAFinetuneParams) -> None: # propagated to the relevant components if self._resume_from_checkpoint: assert ( - params.lora_checkpoint is not None + cfg.lora_checkpoint is not None ), "Must pass lora_checkpoint when resuming training" lora_ckpt = self.load_checkpoint( - ckpt_path=params.lora_checkpoint, resume_from_checkpoint=True + ckpt_path=cfg.lora_checkpoint, resume_from_checkpoint=True ) self._update_recipe_state(lora_ckpt) self._model = self._setup_model( - model=params.model, - lora_attn_modules=params.lora_attn_modules, - lora_rank=params.lora_rank, - lora_alpha=params.lora_alpha, - enable_fsdp=params.enable_fsdp, - enable_activation_checkpointing=params.enable_activation_checkpointing, + cfg_model=cfg.model, + enable_fsdp=cfg.enable_fsdp, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, base_model_state_dict=base_model_ckpt[MODEL_KEY], lora_weights_state_dict=lora_ckpt[MODEL_KEY] if self._resume_from_checkpoint @@ -136,32 +128,28 @@ def setup(self, params: LoRAFinetuneParams) -> None: ) self._tokenizer = self._setup_tokenizer( - tokenizer=params.tokenizer, tokenizer_checkpoint=params.tokenizer_checkpoint + cfg_tokenizer=cfg.tokenizer, ) self._optimizer = self._setup_optimizer( - optimizer=params.optimizer, - lr=params.lr, - weight_decay=params.weight_decay, + cfg_optimizer=cfg.optimizer, opt_state_dict=lora_ckpt[OPT_KEY] if self._resume_from_checkpoint else None, ) - self._loss_fn = self._setup_loss(loss=params.loss) + self._loss_fn = self._setup_loss(cfg_loss=cfg.loss) # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after all of these are setup self._sampler, self._dataloader = self._setup_data( - dataset=params.dataset, - shuffle=params.shuffle, - batch_size=params.batch_size, - train_on_input=params.train_on_input, - use_clean=params.use_clean, + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, ) # training setup self._autocast = utils.get_autocast(self._dtype, self._device) if self._dtype == torch.float16: - self._grad_scaler = utils.get_gradient_scaler(fsdp=params.enable_fsdp) + self._grad_scaler = utils.get_gradient_scaler(fsdp=cfg.enable_fsdp) else: self._grad_scaler = GradScaler(enabled=False) @@ -182,8 +170,7 @@ def setup(self, params: LoRAFinetuneParams) -> None: # Learning rate scheduler can only be set up after number of steps # has been computed self._lr_scheduler = self._setup_lr_scheduler( - lr_scheduler=params.lr_scheduler, - num_warmup_steps=params.num_warmup_steps, + cfg_lr_scheduler=cfg.lr_scheduler, num_training_steps=self.total_epochs * steps_per_epoch, last_epoch=self.total_training_steps - 1, ) @@ -218,10 +205,7 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: def _setup_model( self, - model: str, - lora_attn_modules: List[str], - lora_rank: int, - lora_alpha: float, + cfg_model: DictConfig, enable_fsdp: bool, enable_activation_checkpointing: bool, base_model_state_dict: Dict[str, Any], @@ -230,19 +214,14 @@ def _setup_model( # LoRA recipe uses meta device for FSDP init to avoid peak memory reserved # during model init init_device = "meta" if enable_fsdp else self._device - model = models.get_model( - model, - device=init_device, - 
lora_attn_modules=lora_attn_modules, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - ) + with init_device: + model = hydra.utils.instantiate(cfg_model) - reset_lora_params(model, device=self._device) + reset_lora_cfg(model, device=self._device) # Note: this needs to be set before wrapping with FSDP - self.adapter_params = get_adapter_params(model) - set_trainable_params(model, self.adapter_params) + self.adapter_cfg = get_adapter_cfg(model) + set_trainable_cfg(model, self.adapter_cfg) if enable_fsdp: model = utils.wrap_fsdp( @@ -256,8 +235,8 @@ def _setup_model( param_init_fn=partial(lora_fsdp_init, device=self._device), ) - # Ensure no params and buffers are on meta device - validate_no_meta_params(model) + # Ensure no cfg and buffers are on meta device + validate_no_meta_cfg(model) if enable_activation_checkpointing: utils.set_activation_checkpointing( @@ -265,7 +244,7 @@ def _setup_model( ) validate_state_dict_for_lora( - lora_modules=lora_attn_modules, + lora_modules=cfg_model.lora_attn_modules, full_model_state_dict_keys=model.state_dict().keys(), lora_state_dict_keys=lora_weights_state_dict.keys() if lora_weights_state_dict is not None @@ -281,30 +260,27 @@ def _setup_model( return model def _setup_tokenizer( - self, tokenizer: str, tokenizer_checkpoint: str + self, + cfg_tokenizer: DictConfig, ) -> modules.Tokenizer: """ Unlike ```setup_model```, this takes in the checkpoint and loads the sentencepiece tokenizer model. This is related to how the tokenizer is implemented and should change in a future iteration. """ - tokenizer = models.get_tokenizer(tokenizer, path=tokenizer_checkpoint) + tokenizer = hydra.utils.instantiate(cfg_tokenizer) if self._is_rank_zero: log.info("Tokenizer is initialized from file.") return tokenizer def _setup_optimizer( - self, - optimizer: str, - lr: float, - weight_decay: float, - opt_state_dict: Optional[Dict[str, Any]] = None, + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None ) -> Optimizer: - optimizer = modules.get_optimizer(optimizer, self._model, lr, weight_decay) + optimizer = hydra.utils.instantiate(cfg_optimizer, self._model.parameters()) if opt_state_dict: # Note: technically we should check _contains_fsdp for - # just the state dict of the adapter params, but should be equivalent + # just the state dict of the adapter cfg, but should be equivalent opt_state_dict = utils.transform_opt_state_dict( opt_state_dict, self._model, optimizer ) @@ -316,15 +292,12 @@ def _setup_optimizer( def _setup_lr_scheduler( self, - lr_scheduler: str, - num_warmup_steps: int, + cfg_lr_scheduler: DictConfig, num_training_steps: int, last_epoch: int, ) -> Optimizer: - lr_scheduler = modules.get_lr_scheduler( - lr_scheduler, - self._optimizer, - num_warmup_steps=num_warmup_steps, + lr_scheduler = hydra.utils.instantiate( + cfg_lr_scheduler, num_training_steps=num_training_steps, last_epoch=last_epoch, ) @@ -332,8 +305,8 @@ def _setup_lr_scheduler( log.info("Learning rate scheduler is initialized.") return lr_scheduler - def _setup_loss(self, loss: str) -> nn.Module: - loss_fn = modules.get_loss(loss) + def _setup_loss(self, cfg_loss: DictConfig) -> nn.Module: + loss_fn = hydra.utils.instantiate(cfg_loss) if self._is_rank_zero: log.info("Loss is initialized.") @@ -342,11 +315,9 @@ def _setup_loss(self, loss: str) -> nn.Module: def _setup_data( self, - dataset: str, + cfg_dataset: DictConfig, shuffle: bool, batch_size: int, - train_on_input: bool, - use_clean: bool, ) -> Tuple[DistributedSampler, DataLoader]: """ All data related setup 
happens here. Currently this recipe only supports the @@ -354,11 +325,9 @@ def _setup_data( iterable datasets and streaming datasets are not supported. """ world_size, rank = utils.get_world_size_and_rank() - ds = datasets.get_dataset( - dataset, - split="train", + ds = hydra.utils.instantiate( + cfg_dataset, tokenizer=self._tokenizer, - train_on_input=train_on_input, ) sampler = DistributedSampler( ds, @@ -403,7 +372,7 @@ def save_checkpoint(self, epoch: int) -> None: } ) utils.save_checkpoint( - ckpt_dict, output_loc, model_key_filter=lambda x: x in self.adapter_params + ckpt_dict, output_loc, model_key_filter=lambda x: x in self.adapter_cfg ) if self._is_rank_zero: @@ -476,28 +445,21 @@ def cleanup(self) -> None: self._metric_logger.close() -def recipe_main() -> None: +@hydra.main(config_path="configs") +def recipe_main(cfg: DictConfig) -> None: """ Entry point for the recipe. Configurable parameters are read in the following order: - - Parameters specified in ``LoRAFinetuneParams`` + - Parameters specified in ``LoRAFinetunecfg`` - Overwritten by Parameters specified in ``alpaca_llama2_lora_finetune.yaml`` - Overwritten by arguments from the command-line using ``TuneArgumentParser`` """ - parser = utils.TuneArgumentParser( - description=LoRAFinetuneParams.__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - args, _ = parser.parse_known_args() - args = vars(args) - recipe_params = LoRAFinetuneParams(**args) - # Env variables set by torch run; only need to initialize process group init_process_group(backend="nccl") - recipe = LoRAFinetuneRecipe(params=recipe_params) - recipe.setup(params=recipe_params) + recipe = LoRAFinetuneRecipe(cfg=cfg) + recipe.setup(cfg=cfg) recipe.train() recipe.cleanup() diff --git a/recipes/params/full_finetune.py b/recipes/params/full_finetune.py deleted file mode 100644 index 0e8ff1585f..0000000000 --- a/recipes/params/full_finetune.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass, fields -from typing import Optional - -from torchtune.datasets import ALL_DATASETS -from torchtune.models import ALL_MODELS, ALL_TOKENIZERS -from torchtune.utils.metric_logging import ALL_METRIC_LOGGERS -from torchtune.utils.precision import PRECISION_STR_TO_DTYPE - - -@dataclass -class FullFinetuneParams: - """Arguments for the finetune_llm recipe. - - Args: - device (str): Device to use for training. Options are "cpu" and "cuda" - dtype (str): Data type to use for training. - seed (int): Random seed to use for training. - model (str): String specifying model architecture to fine-tune. See ``torchtune.models.get_model`` for options. - model_checkpoint (str): Local path to load model checkpoint from. - tokenizer (str): String specifying tokenizer to use. See ``torchtune.models.get_tokenizer`` for options. - tokenizer_checkpoint (str): Local path to load tokenizer checkpoint from. - dataset (str): String specifying dataset to use. See ``torchtune.datasets.get_dataset`` for options. - Currently, only predefined datasets in library are supported. - shuffle (bool): Whether to shuffle dataset. - batch_size (int): Batch size to use for training. - epochs (int): Number of epochs to train for. - optimizer (str): String specifying optimizer to use. See ``torchtune.optim.get_optimizer`` for options. 
- loss (str): String specifying loss function to use. See ``torchtune.losses.get_loss`` for options. - lr (float): Learning rate to use for optimizer. - activation_checkpointing (bool): Whether to use activation checkpointing. - output_dir (str): Local path to save checkpoints and logs to. - run_generation (int): Run eval on a prompt every ``run_generation`` steps. Set to 0 to disable. - max_steps_per_epoch (int): Maximum number of steps to take per epoch. - metric_logger_type (str): String specifying metric logger to use. See ``torchtune.utils.get_metric_logger`` - for options. - project (str): Project name to use for logging. Used by ``WandBLogger``. - resume_from_previous_checkpoint (bool): Whether to resume fine-tuning from a previous checkpoint. - cpu_offload (bool): Whether to offload model to CPU. - - Raises: - ValueError: If ``cpu_offload`` is ``True`` but ``device`` is not ``cuda`` and <= 1 GPUs. - """ - - # Model - model: str = "" - model_checkpoint: str = "" - - # Tokenizer - tokenizer: str = "" - tokenizer_checkpoint: str = "" - - # Dataset and Sampler - dataset: str = "" - train_on_input: bool = True - shuffle: bool = True - batch_size: int = 2 - - # Optimizer and Scheduler - optimizer: str = "SGD" - lr: float = 2e-5 - loss: str = "CrossEntropyLoss" - gradient_accumulation_steps: int = 1 - - # Training - epochs: int = 3 - max_steps_per_epoch: Optional[int] = None - resume_from_checkpoint: bool = False - run_generation: Optional[int] = None - - # Distributed - cpu_offload: bool = False - enable_fsdp: bool = True - enable_activation_checkpointing: bool = True - - # Environment - device: str = "cuda" - dtype: str = "fp32" - seed: Optional[int] = None - - # Logging - output_dir: str = "/tmp/full_finetune_output" - metric_logger_type: str = "disk" - project: Optional[str] = None - log_every_n_steps: Optional[int] = None - - def __post_init__(self): - for param in fields(self): - if getattr(self, param.name) == "": - raise TypeError(f"{param.name} needs to be specified") - - if self.cpu_offload and self.device != "cuda": - raise ValueError( - "Cannot offload model to CPU if device is not cuda or <= 1 GPUs." - ) - if self.enable_fsdp and self.device == "cpu": - raise ValueError("FSDP is not supported on CPU.") - if self.model not in ALL_MODELS: - raise ValueError( - f"Model not recognized. Expected one of {ALL_MODELS}, received {self.model}." - ) - if self.tokenizer not in ALL_TOKENIZERS: - raise ValueError( - f"Tokenizer not recognized. Expected one of {ALL_TOKENIZERS}, received {self.tokenizer}." - ) - if self.dataset not in ALL_DATASETS: - raise ValueError( - f"Dataset not recognized. Expected one of {ALL_DATASETS}, received {self.dataset}." - ) - if self.metric_logger_type not in ALL_METRIC_LOGGERS: - raise ValueError( - f"Metric logger not recognized. Expected one of {ALL_METRIC_LOGGERS}, received {self.metric_logger_type}." - ) - if self.dtype not in PRECISION_STR_TO_DTYPE: - raise ValueError( - f"Dtype {self.dtype} must be one of {', '.join(PRECISION_STR_TO_DTYPE.keys())} for finetuning." - ) diff --git a/recipes/params/lora_finetune.py b/recipes/params/lora_finetune.py deleted file mode 100644 index e49f0273ae..0000000000 --- a/recipes/params/lora_finetune.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
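The LoRA-specific fields of the ``LoRAFinetuneParams`` dataclass removed below (``lora_attn_modules``, ``lora_rank``, ``lora_alpha``) now live under the ``model`` node of ``alpaca_llama2_lora_finetune.yaml`` and are passed straight through ``hydra.utils.instantiate``. A sketch of that node in isolation; note that Hydra resolves ``_target_`` as a fully qualified dotted import path, so a fully qualified name is assumed here, spelled the same way ``torchtune.models.llama2_7b`` is in the full-finetune config:

    import hydra
    from omegaconf import OmegaConf

    model_cfg = OmegaConf.create(
        {
            "_target_": "torchtune.models.lora_llama2_7b",  # assumed import path
            "lora_attn_modules": ["q_proj", "v_proj"],
            "lora_rank": 8,
            "lora_alpha": 16,
        }
    )
    model = hydra.utils.instantiate(model_cfg)  # builds the LoRA-wrapped Llama2-7B
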
- -from dataclasses import dataclass, field, fields -from typing import List, Optional - -from torchtune.datasets import ALL_DATASETS -from torchtune.models import ALL_MODELS, ALL_TOKENIZERS -from torchtune.utils.metric_logging import ALL_METRIC_LOGGERS -from torchtune.utils.precision import PRECISION_STR_TO_DTYPE - - -@dataclass -class LoRAFinetuneParams: - """Arguments for the finetune_lora recipe. Note that LoRA is currently only supported - for attention modules (i.e. Q, K, V, output projections), and not for MLP layers. - - Args: - model (str): String specifying model architecture to fine-tune. See ``torchtune.models.get_model`` for options. - model_checkpoint (str): Local path to load model checkpoint from. - lora_attn_modules (List[str]): List of attention modules to use for LoRA. Supported values are - ["q_proj", "k_proj", "v_proj", "output_proj"]. - lora_rank (int): Rank of LoRA decompositions. - lora_alpha (float): Alpha parameter for LoRA. - lora_checkpoint (str): Local path to load LoRA weights from. - tokenizer (str): String specifying tokenizer to use. See ``torchtune.models.get_tokenizer`` for options. - tokenizer_checkpoint (str): Local path to load tokenizer checkpoint from. - dataset (str): String specifying dataset to use. See ``torchtune.datasets.get_dataset`` for options. - Currently, only predefined datasets in library are supported. - train_on_input (bool): Whether to train on the prompt in addition to the response. - use_clean (bool): Whether to use cleaned version of Alpaca dataset or not. - shuffle (bool): Whether to shuffle dataset. - batch_size (int): Batch size to use for training. - epochs (int): Number of epochs to train for. - optimizer (str): String specifying optimizer to use. See ``torchtune.optim.get_optimizer`` for options. - weight_decay (float): Weight decay to use for optimizer. - lr (float): Base learning rate rate to use for optimizer. - lr_scheduler (str): String specifying learning rate scheduler to use. See - ``torchtune.lr_schedulers.get_lr_scheduler`` for options. - num_warmup_steps (int): Number of warmup steps to use for learning rate scheduler. - loss (str): String specifying loss function to use. See ``torchtune.losses.get_loss`` for options. - epochs (int): Number of epochs to train for. - max_steps_per_epoch (int): Maximum number of steps to take per epoch. - resume_from_checkpoint (bool): Whether to resume fine-tuning from a previous checkpoint. - cpu_offload (bool): Whether to offload model to CPU. - enable_fsdp (bool): Whether to use FSDP. - enable_activation_checkpointing (bool): Whether to use activation checkpointing. - device (str): Device to use for training. Options are "cpu" and "cuda" - dtype (str): Data type to use for training. - seed (int): Random seed to use for training. - output_dir (str): Local path to save checkpoints and logs to. - metric_logger_type (str): String specifying metric logger to use. See ``torchtune.utils.get_metric_logger`` - for options. - project (str): Project name to use for logging. Used by ``WandBLogger``. - log_every_n_steps (int): How often to log metrics. 
- """ - - # Model - model: str = "" - model_checkpoint: str = "" - lora_attn_modules: List[str] = field(default_factory=list) - lora_rank: int = 8 - lora_alpha: float = 16 - lora_checkpoint: Optional[str] = None - - # Tokenizer - tokenizer: str = "" - tokenizer_checkpoint: str = "" - - # Dataset and Sampler - dataset: str = "" - train_on_input: bool = True - use_clean: bool = True - shuffle: bool = True - batch_size: int = 2 - - # Optimizer and Scheduler - optimizer: str = "AdamW" - weight_decay: float = 0.01 - lr: float = 3e-4 - lr_scheduler: str = "cosine_with_warmup" - num_warmup_steps: int = 100 - loss: str = "CrossEntropyLoss" - - # Training - epochs: int = 1 - max_steps_per_epoch: Optional[int] = None - resume_from_checkpoint: bool = False - - # Distributed - cpu_offload: bool = False - enable_fsdp: bool = True - enable_activation_checkpointing: bool = True - - # Environment - device: str = "cuda" - dtype: str = "fp32" - seed: Optional[int] = None - - # Logging - output_dir: str = "/tmp/lora_finetune_output" - metric_logger_type: str = "disk" - project: Optional[str] = None - log_every_n_steps: Optional[int] = None - - def __post_init__(self): - for param in fields(self): - if getattr(self, param.name) == "": - raise TypeError(f"{param.name} needs to be specified") - - if self.cpu_offload and self.device != "cuda": - raise ValueError( - "Cannot offload model to CPU if device is not cuda or <= 1 GPUs." - ) - if self.enable_fsdp and self.device == "cpu": - raise ValueError("FSDP is not supported on CPU.") - if self.model not in ALL_MODELS: - raise ValueError( - f"Model not recognized. Expected one of {ALL_MODELS}, received {self.model}." - ) - if self.tokenizer not in ALL_TOKENIZERS: - raise ValueError( - f"Tokenizer not recognized. Expected one of {ALL_TOKENIZERS}, received {self.tokenizer}." - ) - if self.dataset not in ALL_DATASETS: - raise ValueError( - f"Dataset not recognized. Expected one of {ALL_DATASETS}, received {self.dataset}." - ) - if self.metric_logger_type not in ALL_METRIC_LOGGERS: - raise ValueError( - f"Metric logger not recognized. Expected one of {ALL_METRIC_LOGGERS}, received {self.metric_logger_type}." - ) - if self.dtype not in PRECISION_STR_TO_DTYPE: - raise ValueError( - f"Dtype {self.dtype} must be one of {', '.join(PRECISION_STR_TO_DTYPE.keys())} for finetuning." - ) - if len(self.lora_attn_modules) == 0: - raise ValueError("Must specify at least one module to apply LoRA to") diff --git a/requirements.txt b/requirements.txt index 33149d58f3..7151149637 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,5 +8,4 @@ huggingface_hub==0.19.4 # Misc sentencepiece==0.1.99 tqdm==4.66.1 -omegaconf==2.3.0 hydra-core==1.3.0 diff --git a/tests/torchtune/datasets/test_get_dataset.py b/tests/torchtune/datasets/test_get_dataset.py deleted file mode 100644 index 6a04c3dfed..0000000000 --- a/tests/torchtune/datasets/test_get_dataset.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
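The string registries exercised by the deleted tests below (``datasets.get_dataset``/``list_datasets`` and ``models.get_model``/``get_tokenizer``) go away with them; the packages now simply re-export the builders, so callers either construct objects directly or go through a ``_target_`` config node. A short sketch of the direct calls, using only argument names that appear in the YAML configs:

    import torch
    from torchtune.models import llama2_7b, llama2_tokenizer

    tokenizer = llama2_tokenizer(path="/tmp/tokenizer.model")

    # The recipes wrap model construction in a device context (`with self._device:`)
    # so parameters materialize on the target device; CPU is used here for the sketch.
    with torch.device("cpu"):
        model = llama2_7b()
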
- -from torchtune import datasets - - -class TestDatasetGetter: - def test_get_dataset(self): - """ - Test getting a named dataset - """ - datasets.ALL_DATASETS["test"] = lambda x: x - dataset = datasets.get_dataset("test", x=1) - assert dataset == 1 - - def test_list_datasets(self): - """ - Test accuracy of dataset list - """ - dataset_names = datasets.list_datasets() - assert "test" in dataset_names diff --git a/tests/torchtune/models/test_get_model.py b/tests/torchtune/models/test_get_model.py deleted file mode 100644 index dd1c0504c5..0000000000 --- a/tests/torchtune/models/test_get_model.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from torchtune import models - - -class TestModelTokenizerGetter: - def test_get_model(self): - """ - Test getting a named model - """ - models.ALL_MODELS["test"] = lambda x: x - model = models.get_model("test", "cpu", x=1) - assert model == 1 - - def test_get_model_device(self): - models.ALL_MODELS["test"] = lambda x: x - model = models.get_model("test", device=torch.device("cpu"), x=1) - assert model == 1 - - def test_list_models(self): - """ - Test accuracy of model list - """ - model_names = models.list_models() - assert "test" in model_names - - def test_get_tokenizer(self): - """ - Test getting a named tokenizer - """ - models.ALL_TOKENIZERS["test"] = lambda x: x - tokenizer = models.get_tokenizer("test", x=1) - assert tokenizer == 1 - - def test_list_tokenizer(self): - """ - Test accuracy of tokenizer list - """ - tokenizer_names = models.list_tokenizers() - assert "test" in tokenizer_names diff --git a/tests/torchtune/utils/test_argparse.py b/tests/torchtune/utils/test_argparse.py deleted file mode 100644 index 3acf06eea8..0000000000 --- a/tests/torchtune/utils/test_argparse.py +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -from unittest import mock - -import pytest - -from torchtune.utils import TuneArgumentParser - -_CONFIG = {"a": 1, "b": 2} - - -class TestArgParse: - @pytest.fixture - def parser(self): - parser = TuneArgumentParser("Test parser") - return parser - - @mock.patch("torchtune.utils.argparse.OmegaConf.load", return_value=_CONFIG) - def test_parse_args(self, mock_load, parser): - """ - Test that the parser can load a config and override parameters provided on CLI. - The actual load is mocked to return the test config above. - """ - args = parser.parse_args(["--config", "test.yaml", "--override", "b=3", "c=4"]) - assert args.a == 1, f"a == {args.a} not 1 as set in the config." - assert args.b == 3, f"b == {args.b} not 3 as set in the command args." - assert args.c == 4, f"c == {args.c} not 4 as set in the command args." - assert len(vars(args).keys() - {"a", "b", "c"}) == 0, "Extra args found." 
- - def test_required_argument(self, parser): - """ - Test that the parser does not allow required arguments to be added - """ - with pytest.raises(AssertionError): - parser.add_argument("--d", required=True, type=int, default=0) diff --git a/tests/torchtune/utils/test_metric_logging.py b/tests/torchtune/utils/test_metric_logging.py index 667c396d04..40cad238f7 100644 --- a/tests/torchtune/utils/test_metric_logging.py +++ b/tests/torchtune/utils/test_metric_logging.py @@ -16,38 +16,12 @@ from torchtune.utils.metric_logging import ( DiskLogger, - get_metric_logger, - list_metric_loggers, StdoutLogger, TensorBoardLogger, WandBLogger, ) -class TestMetricLogger: - def test_list_metric_loggers(self) -> None: - assert set(list_metric_loggers()) == { - "disk", - "stdout", - "tensorboard", - "wandb", - } - - def test_get_metric_logger(self) -> None: - fake_kwargs = { - "log_dir": "/tmp/output", - "project": "test-project", - "extra_key": "bananas", - } - assert isinstance(get_metric_logger("disk", **fake_kwargs), DiskLogger) - assert isinstance(get_metric_logger("stdout", **fake_kwargs), StdoutLogger) - assert isinstance( - get_metric_logger("tensorboard", **fake_kwargs), TensorBoardLogger - ) - with patch("wandb.init") as wandb_init: - assert isinstance(get_metric_logger("wandb", **fake_kwargs), WandBLogger) - - class TestDiskLogger: def test_log(self) -> None: with tempfile.TemporaryDirectory() as log_dir: diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index c31407859f..2190f91354 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -4,24 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torch.utils.data import Dataset - from .alpaca import AlpacaDataset from .slimorca import SlimOrcaDataset -ALL_DATASETS = {"alpaca": AlpacaDataset, "slimorca": SlimOrcaDataset} - - -def get_dataset(name: str, **kwargs) -> Dataset: - """Get known supported datasets by name""" - if name in ALL_DATASETS: - return ALL_DATASETS[name](**kwargs) - else: - raise ValueError( - f"Dataset not recognized. Expected one of {ALL_DATASETS}, received {name}" - ) - - -def list_datasets(): - """List of availabe datasets supported by `get_dataset`""" - return list(ALL_DATASETS) +__all__ = [ + "AlpacaDataset", + "SlimOrcaDataset", +] diff --git a/torchtune/losses.py b/torchtune/losses.py deleted file mode 100644 index b4290b095f..0000000000 --- a/torchtune/losses.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from torch import nn - - -def get_loss(loss: str) -> nn.Module: - """Returns a loss function from torch.nn. - - Args: - loss (str): name of the loss function. - - Returns: - nn.Module: loss function. - - Raises: - ValueError: if the loss is not a valid loss from torch.nn. 
- """ - try: - return getattr(nn, loss)() - except AttributeError as e: - raise ValueError(f"{loss} is not a valid loss from torch.nn") from e - - -# TODO convert to folder when we support llm specific losses diff --git a/torchtune/models/__init__.py b/torchtune/models/__init__.py index 2e8e3ae472..b5d1e8fb50 100644 --- a/torchtune/models/__init__.py +++ b/torchtune/models/__init__.py @@ -4,49 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable, Union - -import torch -from torch.nn import Module - -from torchtune.utils import get_device - from .llama2 import llama2_7b, llama2_tokenizer from .lora_llama2 import lora_llama2, lora_llama2_7b __all__ = ["llama2_7b", "llama2_tokenizer", "lora_llama2", "lora_llama2_7b"] - -ALL_MODELS = {"llama2_7b": llama2_7b, "lora_llama2_7b": lora_llama2_7b} -ALL_TOKENIZERS = {"llama2_tokenizer": llama2_tokenizer} - - -def get_model(name: str, device: Union[str, torch.device], **kwargs) -> Module: - """Get known supported models by name""" - if name in ALL_MODELS: - with get_device(device): - model = ALL_MODELS[name](**kwargs) - return model - else: - raise ValueError( - f"Model not recognized. Expected one of {ALL_MODELS}, received {name}" - ) - - -def get_tokenizer(name: str, **kwargs) -> Callable: - """Get known supported tokenizers by name""" - if name in ALL_TOKENIZERS: - return ALL_TOKENIZERS[name](**kwargs) - else: - raise ValueError( - f"Tokenizer not recognized. Expected one of {ALL_TOKENIZERS}, received {name}" - ) - - -def list_models(): - """List of availabe models supported by `get_model`""" - return list(ALL_MODELS) - - -def list_tokenizers(): - """List of availabe tokenizers supported by `get_tokenizer`""" - return list(ALL_TOKENIZERS) diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index ac011ccd3a..2603a3a060 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -4,12 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import torch - -from torch import nn -from torch.optim.lr_scheduler import LRScheduler -from torch.optim.optimizer import Optimizer - from .attention import CausalSelfAttention # noqa from .feed_forward import FeedForward # noqa from .kv_cache import KVCache # noqa @@ -30,79 +24,3 @@ "TransformerDecoder", "TransformerDecoderLayer", ] - - -def get_loss(loss: str) -> nn.Module: - """Returns a loss function from torch.nn. - - Args: - loss (str): name of the loss function. - - Returns: - nn.Module: loss function. - - Raises: - ValueError: if the loss is not a valid loss from torch.nn. - """ - try: - return getattr(nn, loss)() - except AttributeError as e: - raise ValueError(f"{loss} is not a valid loss from torch.nn") from e - - -def get_optimizer( - optimizer: str, model: torch.nn.Module, lr: float, weight_decay: float = 0.0 -) -> Optimizer: - """Returns an optimizer function from torch.optim. - - Args: - optimizer (str): name of the optimizer. - model (torch.nn.Module): model to optimize. - lr (float): learning rate. - weight_decay (float): weight decay for optimizer. Default is 0.0. - - Returns: - Optimizer: optimizer function. - - Raises: - ValueError: if the optimizer is not a valid optimizer from torch.optim. 
- """ - try: - trainable_params = [p for n, p in model.named_parameters() if p.requires_grad] - return getattr(torch.optim, optimizer)( - trainable_params, lr=lr, weight_decay=weight_decay - ) - except AttributeError as e: - raise ValueError( - f"{optimizer} is not a valid optimizer from torch.optim" - ) from e - - -ALL_LR_SCHEDULERS = {"cosine_with_warmup": get_cosine_schedule_with_warmup} - - -def get_lr_scheduler( - lr_scheduler: str, optimizer: torch.optim.Optimizer, **kwargs -) -> LRScheduler: - """Returns an optimizer function from torch.optim. - - Args: - lr_scheduler (str): name of the learning rate scheduler. - optimizer (torch.optim.Optimizer): optimizer. - **kwargs: additional arguments to pass to the learning rate scheduler. - - Returns: - LRScheduler: learning rate scheduler. - - Raises: - ValueError: if the lr scheduler is not a valid optimizer from torch.optim. - """ - try: - if lr_scheduler in ALL_LR_SCHEDULERS: - return ALL_LR_SCHEDULERS[lr_scheduler](optimizer, **kwargs) - else: - getattr(torch.optim.lr_scheduler, lr_scheduler)(optimizer, **kwargs) - except AttributeError as e: - raise ValueError( - f"{lr_scheduler} is not a valid learning rate scheduler from torch.optim.lr_scheduler or torchtune" - ) from e diff --git a/torchtune/optim.py b/torchtune/optim.py deleted file mode 100644 index 65f89eafc2..0000000000 --- a/torchtune/optim.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -import torch -from torch.optim.optimizer import Optimizer - - -def get_optimizer( - optimizer: str, model: torch.nn.Module, lr: float, weight_decay: float = 0.0 -) -> Optimizer: - """Returns an optimizer function from torch.optim. - - Args: - optimizer (str): name of the optimizer. - model (torch.nn.Module): model to optimize. - lr (float): learning rate. - weight_decay (float): weight decay for optimizer. Default is 0.0. - - Returns: - Optimizer: optimizer function. - - Raises: - ValueError: if the optimizer is not a valid optimizer from torch.optim. - """ - try: - trainable_params = [p for n, p in model.named_parameters() if p.requires_grad] - return getattr(torch.optim, optimizer)( - trainable_params, lr=lr, weight_decay=weight_decay - ) - except AttributeError as e: - raise ValueError( - f"{optimizer} is not a valid optimizer from torch.optim" - ) from e - - -# TODO convert to folder when we support tuning specific optimizers diff --git a/torchtune/utils/argparse.py b/torchtune/utils/argparse.py deleted file mode 100644 index 170856ece7..0000000000 --- a/torchtune/utils/argparse.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -from argparse import Action, Namespace -from typing import List, Tuple - -from omegaconf import OmegaConf - - -class TuneArgumentParser(argparse.ArgumentParser): - """ - TuneArgumentParser is a helpful utility subclass of the argparse ArgumentParser that - adds a builtin argument "config". The config argument takes a file path to a yaml file - and will load in argument defaults from the yaml file. The yaml file must only contain - argument names and their values and nothing more, it does not have to include all of the - arguments. 
These values will be treated as defaults and can still be overridden from the - command line. Everything else works the same as the base ArgumentParser and you should - consult the docs for more info. - - https://docs.python.org/3/library/argparse.html - - *Note: This class does not support setting "required" arguments.* - *Note: This class uses "config" as a builtin argument so it is not available to use* - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - super().add_argument( - "--config", type=str, help="Path/name of a yaml file with recipe args" - ) - super().add_argument( - "--override", - type=str, - nargs="+", - help="Override config parameters with format KEY=VALUE", - ) - - def parse_known_args(self, *args, **kwargs) -> Tuple[Namespace, List[str]]: - """This acts the same as the base parse_known_args but will first load in defaults from - from the config yaml file if it is provided. The command line args will always take - precident over the values in the config file. All other parsing method, such as parse_args, - internally call this method so they will inherit this property too. For more info see - the docs for the base method. - - https://docs.python.org/3/library/argparse.html#the-parse-args-method - """ - namespace, _ = super().parse_known_args(*args, **kwargs) - if namespace.config is not None: - config = OmegaConf.load(namespace.config) - assert "config" not in config, "Cannot use 'config' within a config file" - self.set_defaults(**config) - if namespace.override is not None: - cli_config = OmegaConf.from_dotlist(namespace.override) - assert "config" not in config, "Cannot use 'override' within CLI arguments" - self.set_defaults(**cli_config) - namespace, unknown_args = super().parse_known_args(*args, **kwargs) - del namespace.config - del namespace.override - return namespace, unknown_args - - def add_argument(self, *args, **kwargs) -> Action: - """This calls the base method but throws an error if the required flag is set or the name used is config. - For more info on the method see the docs for the base method. - - https://docs.python.org/3/library/argparse.html#the-add-argument-method - """ - assert not kwargs.get("required", False), "Required not supported" - return super().add_argument(*args, **kwargs) diff --git a/torchtune/utils/metric_logging.py b/torchtune/utils/metric_logging.py index f0a25a568b..9c6999ba08 100644 --- a/torchtune/utils/metric_logging.py +++ b/torchtune/utils/metric_logging.py @@ -8,7 +8,7 @@ import time from pathlib import Path -from typing import Dict, List, Mapping, Optional, Union +from typing import Mapping, Optional, Union from numpy import ndarray from torch import Tensor @@ -239,41 +239,3 @@ def close(self) -> None: if self._writer: self._writer.close() self._writer = None - - -ALL_METRIC_LOGGERS: Dict[str, "MetricLoggerInterface"] = { - "wandb": WandBLogger, - "tensorboard": TensorBoardLogger, - "stdout": StdoutLogger, - "disk": DiskLogger, -} - - -def list_metric_loggers() -> List[str]: - """List available metric loggers. - - Returns: - List[str]: list of available metric loggers - """ - return list(ALL_METRIC_LOGGERS.keys()) - - -def get_metric_logger(metric_logger_type: str, **kwargs) -> "MetricLoggerInterface": - """Get a metric logger based on provided arguments. - - Args: - metric_logger_type (str): name of the metric logger, options are "wandb", "tensorboard", "stdout", "disk". 
- **kwargs: additional arguments to pass to the metric logger - - Raises: - ValueError: If ``metric_logger`` str is unknown. - - Returns: - MetricLoggerInterface: metric logger - """ - if metric_logger_type not in ALL_METRIC_LOGGERS: - raise ValueError( - f"Metric logger not recognized. Expected one of {list_metric_loggers}, received {metric_logger_type}." - ) - - return ALL_METRIC_LOGGERS[metric_logger_type](**kwargs)
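The deletions above remove the last of the string-keyed registries (get_dataset, get_model, get_tokenizer, get_optimizer, get_loss, get_lr_scheduler, get_metric_logger, TuneArgumentParser). For reference, a minimal sketch of the pattern that replaces them, assuming the recipes resolve the `_target_` config blocks with hydra.utils.instantiate / OmegaConf (hydra-core stays pinned in requirements.txt); the variable names and keyword arguments below are illustrative assumptions, not the exact code in recipes/full_finetune.py or recipes/lora_finetune.py:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Load one of the updated configs, e.g. alpaca_llama2_full_finetune.yaml.
cfg = OmegaConf.load("recipes/configs/alpaca_llama2_full_finetune.yaml")

# Each `_target_` block is built directly from the config, replacing the
# removed get_* helpers and their ALL_* lookup tables.
model = instantiate(cfg.model)                  # torchtune.models.llama2_7b()
tokenizer = instantiate(cfg.tokenizer)          # llama2_tokenizer(path=...)
loss_fn = instantiate(cfg.loss)                 # torch.nn.CrossEntropyLoss()
metric_logger = instantiate(cfg.metric_logger)  # DiskLogger(log_dir=...)

# Blocks that only carry kwargs get their runtime objects passed through at
# instantiation time. (Exact signatures may differ from the real recipes.)
dataset = instantiate(cfg.dataset, tokenizer=tokenizer)
optimizer = instantiate(cfg.optimizer, params=model.parameters())

Passing runtime objects such as the tokenizer or model parameters as keyword arguments to instantiate is the usual Hydra idiom for partially specified blocks like the optimizer config, which is presumably why validation helpers such as LoRAFinetuneParams.__post_init__ are no longer needed.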