From 760d7214d86623f4d7dc6c4632dde0dfd2771eec Mon Sep 17 00:00:00 2001 From: Rafi Ayub <33648637+RdoubleA@users.noreply.github.com> Date: Thu, 29 Feb 2024 16:37:37 -0800 Subject: [PATCH] [RFC][Break BC] Custom instantiate API to simplify our config system (#406) --- docs/source/api_ref_config.rst | 14 + docs/source/api_ref_utilities.rst | 1 - docs/source/examples/configs.rst | 251 +++++++++++------- docs/source/examples/finetune_llm.rst | 26 +- .../examples/first_finetune_tutorial.rst | 23 +- docs/source/examples/lora_finetune.rst | 9 +- docs/source/examples/recipe_deepdive.rst | 25 ++ docs/source/index.rst | 1 + recipes/__init__.py | 12 +- recipes/alpaca_generate.py | 80 ++---- .../configs/alpaca_llama2_full_finetune.yaml | 42 ++- recipes/configs/alpaca_llama2_generate.yaml | 19 ++ .../configs/alpaca_llama2_lora_finetune.yaml | 55 ++-- recipes/full_finetune.py | 140 ++++------ recipes/lora_finetune.py | 176 +++++------- recipes/params/full_finetune.py | 122 --------- recipes/params/lora_finetune.py | 141 ---------- recipes/tests/configs/test_configs.py | 22 +- recipes/tests/test_alpaca_generate.py | 18 +- recipes/tests/test_full_finetune.py | 151 ++++++----- recipes/tests/test_lora_finetune.py | 19 +- recipes/tests/test_params.py | 75 ------ recipes/tests/utils.py | 30 ++- requirements.txt | 2 +- tests/torchtune/config/test_instantiate.py | 76 ++++++ tests/torchtune/config/test_parse.py | 32 +++ tests/torchtune/config/test_utils.py | 28 ++ .../torchtune/datasets/test_alpaca_dataset.py | 15 +- tests/torchtune/datasets/test_get_dataset.py | 24 -- .../datasets/test_slimorca_dataset.py | 23 +- tests/torchtune/models/test_get_model.py | 45 ---- tests/torchtune/utils/test_metric_logging.py | 26 -- torchtune/config/__init__.py | 13 + torchtune/config/_instantiate.py | 104 ++++++++ torchtune/config/_parse.py | 55 ++++ torchtune/config/_utils.py | 85 ++++++ torchtune/datasets/__init__.py | 22 +- torchtune/losses.py | 28 -- torchtune/models/__init__.py | 44 +-- torchtune/modules/__init__.py | 82 ------ torchtune/optim.py | 40 --- torchtune/utils/__init__.py | 3 - torchtune/utils/metric_logging.py | 40 +-- 43 files changed, 1018 insertions(+), 1221 deletions(-) create mode 100644 docs/source/api_ref_config.rst create mode 100644 recipes/configs/alpaca_llama2_generate.yaml delete mode 100644 recipes/params/full_finetune.py delete mode 100644 recipes/params/lora_finetune.py delete mode 100644 recipes/tests/test_params.py create mode 100644 tests/torchtune/config/test_instantiate.py create mode 100644 tests/torchtune/config/test_parse.py create mode 100644 tests/torchtune/config/test_utils.py delete mode 100644 tests/torchtune/datasets/test_get_dataset.py delete mode 100644 tests/torchtune/models/test_get_model.py create mode 100644 torchtune/config/__init__.py create mode 100644 torchtune/config/_instantiate.py create mode 100644 torchtune/config/_parse.py create mode 100644 torchtune/config/_utils.py delete mode 100644 torchtune/losses.py delete mode 100644 torchtune/optim.py diff --git a/docs/source/api_ref_config.rst b/docs/source/api_ref_config.rst new file mode 100644 index 0000000000..80dd45d743 --- /dev/null +++ b/docs/source/api_ref_config.rst @@ -0,0 +1,14 @@ +.. _config: + +================== +torchtune.config +================== + +.. currentmodule:: torchtune.config + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + instantiate + parse diff --git a/docs/source/api_ref_utilities.rst b/docs/source/api_ref_utilities.rst index 58701f5b79..db4d220734 100644 --- a/docs/source/api_ref_utilities.rst +++ b/docs/source/api_ref_utilities.rst @@ -51,7 +51,6 @@ Metric Logging :toctree: generated/ :nosignatures: - metric_logging.get_metric_logger metric_logging.WandBLogger metric_logging.TensorBoardLogger metric_logging.StdoutLogger diff --git a/docs/source/examples/configs.rst b/docs/source/examples/configs.rst index 023c14241f..5c00459c05 100644 --- a/docs/source/examples/configs.rst +++ b/docs/source/examples/configs.rst @@ -4,16 +4,15 @@ Configs Deep-Dive ================= -This tutorial will guide you through writing configs for running recipes and -designing params for custom recipes. +This tutorial will guide you through writing configs for running recipes. .. grid:: 2 .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn * How to write a YAML config and run a recipe with it - * How to create a params dataclass for custom recipe - * How to effectively use configs, CLI overrides, and dataclasses for running recipes + * How to use :code:`instantiate` and :code:`parse` APIs + * How to effectively use configs and CLI overrides for running recipes .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites @@ -27,137 +26,193 @@ Where do parameters live? There are two primary entry points for you to configure parameters: **configs** and **CLI overrides**. Configs are YAML files that define all the -parameters needed to run a recipe within a single location. These can be overridden on the -command-line for quick changes and experimentation without modifying the config. +parameters needed to run a recipe within a single location. They are the single +source of truth for reproducing a run. The config parameters can be overridden on the +command-line using :code:`tune` for quick changes and experimentation without +modifying the config. -If you are planning to make a custom recipe, you will need to become familiar -with the **recipe dataclass**, which collects all of your arguments from config and -CLI, and passes it into the recipe itself. Here, we will discuss all three concepts: -**configs**, **CLI**, and **dataclasses**. +Writing configs +--------------- +Configs serve as the primary entry point for running recipes in TorchTune. They are +expected to be YAML files and they simply list out values for parameters you want to define +for a particular run. -Recipe dataclasses ------------------- +.. code-block:: yaml -Parameters should be organized in a single dataclass that is passed into the recipe. -This serves as a single source of truth for the details of a fine-tuning run that can be easily validated in code and shared with collaborators for reproducibility. + seed: null + shuffle: True + device: cuda + dtype: fp32 + enable_fsdp: True + ... -.. code-block:: python +Configuring components using :code:`instantiate` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Many fields will require specifying TorchTune objects with associated keyword +arguments as parameters. Models, datasets, optimizers, and loss functions are +common examples of this. You can easily do this using the :code:`_component_` +subfield. In :code:`_component_`, you need to specify the dotpath of the object +you wish to instantiate in the recipe. The dotpath is the exact path you would use +to import the object normally in a Python file. For example, to specify the +:class:`~torchtune.datasets._alpaca.AlpacaDataset` in your config with custom +arguments: - class FullFinetuneParams: - # Model - model: str = "" - model_checkpoint: str = "" +.. code-block:: yaml -In the dataclass, all fields should have defaults assigned to them. -If a reasonable value cannot be assigned or it is a required argument, -use the null value for that data type as the default and ensure that it is set -by the user in the :code:`__post_init__` (see :ref:`Parameter Validation`). -The dataclass should go in the :code:`recipes/params/` folder and the name of -the file should match the name of the recipe file you are creating. + dataset: + _component_: torchtune.datasets.AlpacaDataset + train_on_input: False -In general, you should expose the minimal amount of parameters you need to run and experiment with your recipes. -Exposing an excessive number of parameters will lead to bloated configs, which are more error prone, harder to read, and harder to manage. -On the other hand, hardcoding all parameters will prevent quick experimentation without a code change. Only parametrize what is needed. +Here, we are changing the default value for :code:`train_on_input` from :code:`True` +to :code:`False`. -To link the dataclass object with config and CLI parsing, -you can use the :class:`~torchtune.utils.argparse.TuneArgumentParser` object and -funnel the parsed arguments into your dataclass. +Once you've specified the :code:`_component_` in your config, you can create an +instance of the specified object in your recipe's setup like so: .. code-block:: python - if __name__ == "__main__": - parser = utils.TuneArgumentParser( - description=FullFinetuneParams.__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - # Get user-specified args from config and CLI and create params for recipe - args, _ = parser.parse_known_args() - args = vars(args) - params = FullFinetuneParams(**args) + from torchtune import config - logger = utils.get_logger("DEBUG") - logger.info(msg=f"Running finetune_llm.py with parameters {params}") + # Access the dataset field and create the object instance + dataset = config.instantiate(cfg.dataset) - recipe(params) +This will automatically use any keyword arguments specified in the fields under +:code:`dataset`. -.. _parameter_validation_label: - -Parameter validation --------------------- -To validate arguments for your dataclass and recipe, use the :code:`__post_init__` method to house any checks and raised exceptions. +As written, the preceding example will actually throw an error. If you look at the constructor for :class:`~torchtune.datasets._alpaca.AlpacaDataset`, +you'll notice that we're missing a required positional argument, the tokenizer. +Since this is another configurable TorchTune object, let's understand how to handle +this by taking a look at the :func:`~torchtune.config._instantiate.instantiate` API. .. code-block:: python - def __post_init__(self): - for param in fields(self): - if getattr(self, param.name) == "": - raise TypeError(f"{param.name} needs to be specified") + def instantiate( + config: DictConfig, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], + ) -Writing configs ---------------- -Once you've set up a recipe and its params, you need to create a config to run it. -Configs serve as the primary entry point for running recipes in TorchTune. They are -expected to be YAML files and simply list out values for parameters you want to define -for a particular run. The config parameters should be a subset of the dataclass parameters; -there should not be any new fields that are not already in the dataclass. Any parameters that -are not specified in the config will take on the default value defined in the dataclass. +:func:`~torchtune.config._instantiate.instantiate` also accepts positional arguments +and keyword arguments and automatically uses that with the config when creating +the object. This means we can not only pass in the tokenizer, but also add additional +keyword arguments not specified in the config if we'd like: .. code-block:: yaml - dataset: alpaca - seed: null - shuffle: True - model: llama2_7b - ... + # Tokenizer is needed for the dataset, configure it first + tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/tokenizer.model -Command-line overrides ----------------------- -To enable quick experimentation, you can specify override values to parameters in your config -via the :code:`tune` command. These should be specified with the flag :code:`--override k1=v1 k2=v2 ...` + dataset: + _component_: torchtune.datasets.AlpacaDataset + train_on_input: True -For example, to run the :code:`full_finetune` recipe with custom model and tokenizer directories and using GPUs, you can provide overrides: +.. code-block:: python -.. code-block:: bash + # Note the API of the tokenizer we specified - we need to pass in a path + def llama2_tokenizer(path: str) -> Tokenizer; + + # Note the API of the dataset we specified - we need to pass in a tokenizer + # and any optional keyword arguments + class AlpacaDataset(Dataset): + def __init__( + self, + tokenizer: Tokenizer, + train_on_input: bool = True, + use_clean: bool = False, + **kwargs, + ) -> None; + + from torchtune import config + + # Since we've already specified the path in the config, we don't need to pass + # it in + tokenizer = config.instantiate(cfg.tokenizer) + # We pass in the instantiated tokenizer as the first required argument, then + # we change an optional keyword argument + dataset = config.instantiate( + cfg.dataset, + tokenizer, + use_clean=True, + ) + +Note that additional keyword arguments will overwrite any duplicated keys in the +config. + +Referencing other config fields with interpolations +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Sometimes you need to use the same value more than once for multiple fields. You +can use *interpolations* to reference another field, and :func:`~torchtune.config._instantiate.instantiate` +will automatically resolve it for you. - tune full_finetune --config alpaca_llama2_full_finetune --override model_directory=/home/my_model_checkpoint tokenizer_directory=/home/my_tokenizer_checkpoint device=cuda +.. code-block:: yaml -The order of overrides from these parameter sources is as follows, with highest precedence first: CLI, Config, Dataclass defaults + output_dir: /tmp/alpaca-llama2-finetune + metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} -Testing configs ---------------- -If you plan on contributing your config to the repo, we recommend adding it to the testing suite. TorchTune has testing for every config added to the library, namely ensuring that it instantiates the dataclass and runs the recipe correctly. +Best practices for writing configs +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Let's discuss some guidelines for writing configs to get the most out of them. -To add your config to this test suite, simply update the dictionary in :code:`recipes/tests/configs/test_configs`. +Airtight configs +"""""""""""""""" +While it may be tempting to put as much as you can in the config to give you +maximum flexibility in switching parameters for your experiments, we encourage +you to only include fields in the config that will be used or instantiated in the +recipe. This ensures full clarity on the options a recipe was run with and will +make it significantly easier to debug. -.. code-block:: python +.. code-block:: yaml + + # dont do this + alpaca_dataset: + _component_: torchtune.datasets.AlpacaDataset + train_on_input: True + slimorca_dataset: + ... + + # do this + dataset: + # change this in config or override when needed + _component_: torchtune.datasets.AlpacaDataset + train_on_input: True + +Use public APIs only +"""""""""""""""""""" +If a component you wish to specify in a config is located in a private file, use +the public dotpath in your config. These components are typically exposed in their +parent module's :code:`__init__.py` file. This way, you can guarantee the stability +of the API you are using in your config. There should be no underscores in your +component dotpath. - config_to_params = { - os.path.join(ROOT_DIR, "alpaca_llama2_full_finetune.yaml"): FullFinetuneParams, - ..., - } +.. code-block:: yaml -Linking recipes and configs with :code:`tune` ---------------------------------------------- + # don't do this + dataset: + _component_: torchtune.datasets._alpaca.AlpacaDataset + train_on_input: True -In order to run your custom recipe and configs with :code:`tune`, you must update the :code:`_RECIPE_LIST` -and :code:`_CONFIG_LISTS` in :code:`recipes/__init__.py` + # do this + dataset: + _component_: torchtune.datasets.AlpacaDataset + train_on_input: True -.. code-block:: python - _RECIPE_LIST = ["full_finetune", "lora_finetune", "alpaca_generate", ...] - _CONFIG_LISTS = { - "full_finetune": ["alpaca_llama2_full_finetune"], - "lora_finetune": ["alpaca_llama2_lora_finetune"], - "alpaca_generate": [], - "": [" --config --override ... + tune full_finetune --config alpaca_llama2_full_finetune --override model_directory=/home/my_model_checkpoint tokenizer_directory=/home/my_tokenizer_checkpoint device=cuda diff --git a/docs/source/examples/finetune_llm.rst b/docs/source/examples/finetune_llm.rst index 1169a7a50f..c7b9e09eb3 100644 --- a/docs/source/examples/finetune_llm.rst +++ b/docs/source/examples/finetune_llm.rst @@ -28,23 +28,29 @@ An example config for training the Llama 7B model using the Alpaca dataset looks .. code-block:: yaml - # Dataset and Dataloader - dataset: alpaca - seed: null + # Tokenizer + tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/tokenizer.model + + # Dataset + dataset: + _component_: torchtune.datasets.AlpacaDataset shuffle: True # Model Arguments - model: llama2_7b + model: + _component_: torchtune.models.llama2.llama2_7b model_checkpoint: /tmp/llama2-7b - tokenizer: llama2_tokenizer - tokenizer_checkpoint: /tmp/tokenizer.model # Fine-tuning arguments batch_size: 2 - lr: 2e-5 epochs: 3 - optimizer: SGD - loss: CrossEntropyLoss + optimizer: + _component_: torch.optim.SGD + lr: 2e-5 + loss: + _component_: torch.nn.CrossEntropyLoss output_dir: /tmp/alpaca-llama2-finetune device: cuda dtype: fp32 @@ -68,7 +74,7 @@ from Stanford. The following parameters are related to the data: # Point the dataset to the Alpaca Dataset implementation in TorchTune # This is set in the config - dataset: alpaca + dataset: AlpacaDataset # Don't mask the prompt during training # This is the default value diff --git a/docs/source/examples/first_finetune_tutorial.rst b/docs/source/examples/first_finetune_tutorial.rst index fe965a34e4..87b2cc302b 100644 --- a/docs/source/examples/first_finetune_tutorial.rst +++ b/docs/source/examples/first_finetune_tutorial.rst @@ -95,23 +95,30 @@ lowering the epochs to 1 so you can see results sooner, and updating the learnin .. code-block:: yaml - # Dataset and Dataloader - dataset: alpaca + # Tokenizer + tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/tokenizer.model + + # Dataset + dataset: + _component_: torchtune.datasets.AlpacaDataset seed: 42 shuffle: True # Model Arguments - model: llama2_7b + model: + _component_: torchtune.models.llama2.llama2_7b model_checkpoint: /tmp/llama2/native_pytorch_model.pt - tokenizer: llama2_tokenizer - tokenizer_checkpoint: /tmp/llama2/tokenizer.model # Fine-tuning arguments batch_size: 2 - lr: 1e-5 epochs: 1 - optimizer: SGD - loss: CrossEntropyLoss + optimizer: + _component_: torch.optim.SGD + lr: 1e-5 + loss: + _component_: torch.nn.CrossEntropyLoss output_dir: /tmp/alpaca-llama2-finetune device: cuda dtype: fp32 diff --git a/docs/source/examples/lora_finetune.rst b/docs/source/examples/lora_finetune.rst index 131575161d..2bf4b88b5d 100644 --- a/docs/source/examples/lora_finetune.rst +++ b/docs/source/examples/lora_finetune.rst @@ -272,10 +272,11 @@ Let's take a closer look at some of the :code:`alpaca_llama2_lora_finetune` conf .. code-block:: yaml # Model Arguments - model: lora_llama2_7b - lora_attn_modules: ['q_proj', 'v_proj'] - lora_rank: 8 - lora_alpha: 16 + model: + _component_: lora_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + lora_rank: 8 + lora_alpha: 16 ... We see that the default is to apply LoRA to Q and V projections with a rank of 8. diff --git a/docs/source/examples/recipe_deepdive.rst b/docs/source/examples/recipe_deepdive.rst index 692f1aec3d..fa9c9d031e 100644 --- a/docs/source/examples/recipe_deepdive.rst +++ b/docs/source/examples/recipe_deepdive.rst @@ -188,3 +188,28 @@ Running Recipes with Configs To run a recipe with a set of user-defined parameters, you will need to write a config file. You can learn all about configs in our :ref:`config tutorial`. + +Config and CLI parsing using :code:`parse` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +We provide a convenient decorator :func:`~torchtune.config._parse.parse` that wraps +your recipe to enable running from the command-line with :code:`tune` with config +and CLI override parsing. + +.. code-block:: python + + @config.parse + def recipe_main(cfg: DictConfig) -> None: + recipe = FullFinetuneRecipe(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +Running your recipe +^^^^^^^^^^^^^^^^^^^ +You should be able to run your recipe by providing the direct paths to your custom +recipe and custom config using the :code:`tune` command: + +.. code-block:: bash + + tune --config --override ... diff --git a/docs/source/index.rst b/docs/source/index.rst index 19dfe1665d..0487c7fa5d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -99,6 +99,7 @@ TorchTune tutorials. :caption: API Reference :hidden: + api_ref_config api_ref_datasets api_ref_models api_ref_modules diff --git a/recipes/__init__.py b/recipes/__init__.py index 2004bf54f1..f33e23c09f 100644 --- a/recipes/__init__.py +++ b/recipes/__init__.py @@ -5,21 +5,25 @@ # LICENSE file in the root directory of this source tree. -_RECIPE_LIST = ["full_finetune", "lora_finetune", "alpaca_generate"] +_RECIPE_LIST = [ + "full_finetune", + "lora_finetune", + "alpaca_generate", +] _CONFIG_LISTS = { "full_finetune": ["alpaca_llama2_full_finetune"], "lora_finetune": ["alpaca_llama2_lora_finetune"], - "alpaca_generate": [], + "alpaca_generate": ["alpaca_llama2_generate"], } def list_recipes(): - """List of availabe recipes available from the CLI""" + """List of recipes available from the CLI""" return _RECIPE_LIST def list_configs(recipe: str): - """List of availabe configs available from the CLI given a recipe""" + """List of configs available from the CLI given a recipe""" if recipe not in _CONFIG_LISTS: raise ValueError(f"Unknown recipe: {recipe}") return _CONFIG_LISTS[recipe] diff --git a/recipes/alpaca_generate.py b/recipes/alpaca_generate.py index d1ae817d08..d9e576773e 100644 --- a/recipes/alpaca_generate.py +++ b/recipes/alpaca_generate.py @@ -5,9 +5,10 @@ # LICENSE file in the root directory of this source tree. import torch +from omegaconf import DictConfig -from torchtune import models -from torchtune.utils import get_device, get_logger, set_seed, TuneArgumentParser +from torchtune import config +from torchtune.utils import get_device, get_logger, set_seed from torchtune.utils.generation import GenerationUtils # From https://github.com/tatsu-lab/stanford_alpaca/blob/761dc5bfbdeeffa89b8bff5d038781a4055f796a/train.py#L31 @@ -26,22 +27,16 @@ def recipe( - model, - model_checkpoint, - tokenizer, - tokenizer_checkpoint, - instruction, - input, - max_gen_len, + cfg: DictConfig, ): logger = get_logger("DEBUG") # Inference setup - tokenizer = models.get_tokenizer(tokenizer, path=tokenizer_checkpoint) + tokenizer = config.instantiate(cfg.tokenizer) - example = {"instruction": instruction} - if input != "": - example["input"] = input + example = {"instruction": cfg.instruction} + if cfg.input != "": + example["input"] = cfg.input prompt = PROMPT_DICT["prompt_input"].format_map(example) else: prompt = PROMPT_DICT["prompt_no_input"].format_map(example) @@ -52,10 +47,11 @@ def recipe( device = get_device() - decoder = models.get_model(model, device=device, max_batch_size=1) + with device: + decoder = config.instantiate(cfg.model, max_batch_size=1) # Load state_dict into decoder - native_state_dict = torch.load(model_checkpoint, weights_only=True) + native_state_dict = torch.load(cfg.model_checkpoint, weights_only=True) missing, unexpected = decoder.load_state_dict(native_state_dict, strict=False) decoder.eval() @@ -69,7 +65,7 @@ def recipe( prompt_tokens=token_for_generation, incremental_decode=True, min_gen_len=1, - max_gen_len=max_gen_len, + max_gen_len=cfg.max_gen_len, top_p=0, top_k=1, temperature=1.0, @@ -80,52 +76,6 @@ def recipe( logger.info(msg=generated_tokens[0]) -if __name__ == "__main__": - parser = TuneArgumentParser(description="Example 7B native Llama-2 inference.") - parser.add_argument( - "--model", - type=str, - default="llama2_7b", - choices=models.list_models(), - help="Name of the model to finetune.", - ) - parser.add_argument( - "--model-checkpoint", - type=str, - default="/tmp/llama2-7b", - help="Path to native checkpoint file.", - ) - parser.add_argument( - "--tokenizer", - type=str, - default="llama2_tokenizer", - choices=models.list_tokenizers(), - help="Name of the model tokenizer.", - ) - parser.add_argument( - "--tokenizer-checkpoint", - type=str, - default="/tmp/tokenizer.model", - help="Path to tokenization file.", - ) - parser.add_argument( - "--instruction", - type=str, - default="Answer the question.", - help="Instruction for model to respond to.", - ) - parser.add_argument( - "--input", - type=str, - default="What is some cool music from the 1920s?", - help='Additional optional input related to instruction. Pass in "" (empty string) for no input.', - ) - parser.add_argument( - "--max-gen-len", - type=int, - default=64, - help="Max number of tokens to generate", - ) - - kwargs = vars(parser.parse_args()) - recipe(**kwargs) +@config.parse +def main(cfg: DictConfig) -> None: + recipe(cfg) diff --git a/recipes/configs/alpaca_llama2_full_finetune.yaml b/recipes/configs/alpaca_llama2_full_finetune.yaml index d2479e3a48..5598ed5021 100644 --- a/recipes/configs/alpaca_llama2_full_finetune.yaml +++ b/recipes/configs/alpaca_llama2_full_finetune.yaml @@ -1,28 +1,48 @@ -# Runs the full_finetune.py recipe using FullFinetuneParams +# Config for FullFinetuneRecipe in full_finetune.py # # To launch, run the following command from root: # tune --nnodes 1 --nproc_per_node 1 full_finetune --config alpaca_llama2_full_finetune --override model_checkpoint= ... -# Dataset and Dataloader -dataset: alpaca +# Tokenizer +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/llama2/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.AlpacaDataset + train_on_input: True seed: null shuffle: True # Model Arguments -model: llama2_7b +model: + _component_: torchtune.models.llama2.llama2_7b model_checkpoint: /tmp/llama2_native -tokenizer: llama2_tokenizer -tokenizer_checkpoint: /tmp/llama2/tokenizer.model # Fine-tuning arguments batch_size: 2 -lr: 2e-5 epochs: 3 -optimizer: SGD -loss: CrossEntropyLoss -output_dir: /tmp/alpaca-llama2-finetune +optimizer: + _component_: torch.optim.SGD + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +log_every_n_steps: null +run_generation: null +resume_from_checkpoint: False + +# Distributed device: cuda dtype: fp32 enable_fsdp: True enable_activation_checkpointing: True -resume_from_checkpoint: False +cpu_offload: False + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-llama2-finetune diff --git a/recipes/configs/alpaca_llama2_generate.yaml b/recipes/configs/alpaca_llama2_generate.yaml new file mode 100644 index 0000000000..2c4a3f9781 --- /dev/null +++ b/recipes/configs/alpaca_llama2_generate.yaml @@ -0,0 +1,19 @@ +# Config alpaca_generate.py recipe +# +# To launch, run the following command from root: +# tune alpaca_generate --config alpaca_llama2_generate --override model_checkpoint= tokenizer_checkpoint= + +# Model arguments +model: + _component_: torchtune.models.llama2.llama2_7b +model_checkpoint: /tmp/llama2_native + +# Tokenizer arguments +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/llama2/tokenizer.model + +# Generation arguments +instruction: "Answer the question." +input: "What is some cool music from the 1920s?" +max_gen_len: 64 diff --git a/recipes/configs/alpaca_llama2_lora_finetune.yaml b/recipes/configs/alpaca_llama2_lora_finetune.yaml index 4b41c79b6d..02f41f8aa2 100644 --- a/recipes/configs/alpaca_llama2_lora_finetune.yaml +++ b/recipes/configs/alpaca_llama2_lora_finetune.yaml @@ -1,36 +1,57 @@ +# Config for LoRAFinetuneRecipe in lora_finetune.py +# +# To launch, run the following command from root: +# tune --nnodes 1 --nproc_per_node 1 lora_finetune --config alpaca_llama2_lora_finetune --override model_checkpoint= ... + # Model Arguments -model: lora_llama2_7b +model: + _component_: torchtune.models.llama2.lora_llama2_7b + lora_attn_modules: ['q_proj', 'v_proj'] + lora_rank: 8 + lora_alpha: 16 + model_checkpoint: /tmp/llama2_native -lora_attn_modules: ['q_proj', 'v_proj'] -lora_rank: 8 -lora_alpha: 16 lora_checkpoint: null # Tokenizer -tokenizer: llama2_tokenizer -tokenizer_checkpoint: /tmp/llama2/tokenizer.model +tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/llama2/tokenizer.model # Dataset and Sampler -dataset: alpaca -train_on_input: True +dataset: + _component_: torchtune.datasets.AlpacaDataset + train_on_input: True + use_clean: True +seed: null shuffle: True batch_size: 2 # Optimizer and Scheduler -optimizer: AdamW -weight_decay: 0.01 -lr: 3e-4 -lr_scheduler: cosine_with_warmup -num_warmup_steps: 100 -loss: CrossEntropyLoss +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torch.nn.CrossEntropyLoss # Training epochs: 1 resume_from_checkpoint: False +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 + # Environment device: cuda dtype: fp32 - -# Logging -output_dir: /tmp/lora_finetune_output +enable_fsdp: True +enable_activation_checkpointing: True diff --git a/recipes/full_finetune.py b/recipes/full_finetune.py index 4aa1f3c085..ee047e6c84 100644 --- a/recipes/full_finetune.py +++ b/recipes/full_finetune.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import argparse import os import sys @@ -13,9 +12,9 @@ from warnings import warn import torch +from omegaconf import DictConfig from recipes.interfaces import FTRecipeInterface -from recipes.params.full_finetune import FullFinetuneParams from torch import nn from torch.cuda.amp import GradScaler @@ -23,7 +22,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import datasets, models, modules, utils +from torchtune import config, modules, utils from torchtune.utils.constants import ( EPOCHS_KEY, MAX_STEPS_KEY, @@ -57,40 +56,41 @@ class FullFinetuneRecipe(FTRecipeInterface): - Training happens on CUDA (CPU training is not supported) - Checkpoints are ONLY saved at epoch boundaries. Mid-epoch checkpointing is NOT supported. - Datasets are Map-style and data fits in memory (not streamed). + + The following configs can be used to run this recipe: + >>> tune ls + RECIPE CONFIG + full_finetune alpaca_llama2_full_finetune + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file """ - def __init__(self, params: FullFinetuneParams) -> None: + def __init__(self, cfg: DictConfig) -> None: - self._device = utils.get_device(device=params.device) - self._dtype = utils.get_dtype(dtype=params.dtype) + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(dtype=cfg.dtype) # logging attributes - self._output_dir = params.output_dir - self._metric_logger = utils.get_metric_logger( - metric_logger_type=params.metric_logger_type, - project=params.project, - log_dir=params.output_dir, - ) - self._log_every_n_steps = ( - params.log_every_n_steps if params.log_every_n_steps else 1 - ) + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1 # _is_rank_zero is used primarily for logging. In the future, the logger # should directly take care of this _, rank = utils.get_world_size_and_rank() self._is_rank_zero = rank == 0 - # Training params - self._resume_from_checkpoint = params.resume_from_checkpoint - self._enable_fsdp = params.enable_fsdp - self._gradient_accumulation_steps = params.gradient_accumulation_steps + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._enable_fsdp = cfg.enable_fsdp + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = utils.set_seed(seed=params.seed) + self.seed = utils.set_seed(seed=cfg.seed) self.epochs_run = 0 - self.total_epochs = params.epochs - self.max_steps_per_epoch = params.max_steps_per_epoch + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch self.total_training_steps = 0 def load_checkpoint(self, ckpt_path: str): @@ -101,13 +101,14 @@ def load_checkpoint(self, ckpt_path: str): utils.validate_checkpoint(ckpt_dict, self._resume_from_checkpoint) return ckpt_dict - def setup(self, params: FullFinetuneParams) -> None: + def setup(self, cfg: DictConfig) -> None: """ Sets up the recipe state correctly. This includes setting recipe attributes based on the ``resume_from_checkpoint`` flag. """ + self._metric_logger = config.instantiate(cfg.metric_logger) - ckpt_dict = self.load_checkpoint(ckpt_path=params.model_checkpoint) + ckpt_dict = self.load_checkpoint(ckpt_path=cfg.model_checkpoint) # If we're resuming from checkpoint, the recipe's state should be updated before # initializing the training components. This ensures that the seed is correctly @@ -119,40 +120,40 @@ def setup(self, params: FullFinetuneParams) -> None: # should be called before ``_setup_optimizer`` since transforming the optimizer # state dict requires the model self._model = self._setup_model( - model=params.model, - enable_fsdp=params.enable_fsdp, - enable_activation_checkpointing=params.enable_activation_checkpointing, + cfg_model=cfg.model, + enable_fsdp=cfg.enable_fsdp, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, model_state_dict=ckpt_dict[MODEL_KEY], ) - self._tokenizer = self._setup_tokenizer( - tokenizer=params.tokenizer, tokenizer_checkpoint=params.tokenizer_checkpoint - ) + self._tokenizer = config.instantiate(cfg.tokenizer) + if self._is_rank_zero: + log.info("Tokenizer is initialized from file.") # _setup_optimizer should take in ckpt_dict only if training is resumed from # checkpoint. Transforming the opt state dict is handled by this method self._optimizer = self._setup_optimizer( - optimizer=params.optimizer, - lr=params.lr, + cfg_optimizer=cfg.optimizer, opt_state_dict=ckpt_dict[OPT_KEY] if self._resume_from_checkpoint else None, ) - self._loss_fn = self._setup_loss(loss=params.loss) + self._loss_fn = config.instantiate(cfg.loss) + if self._is_rank_zero: + log.info("Loss is initialized.") # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized self._sampler, self._dataloader = self._setup_data( - dataset=params.dataset, - train_on_input=params.train_on_input, - shuffle=params.shuffle, - batch_size=params.batch_size, + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, ) # training setup self._autocast = utils.get_autocast(self._dtype, self._device) self._grad_scaler = None if self._dtype == torch.float16: - self._grad_scaler = utils.get_gradient_scaler(fsdp=params.enable_fsdp) + self._grad_scaler = utils.get_gradient_scaler(fsdp=cfg.enable_fsdp) else: self._grad_scaler = GradScaler(enabled=False) @@ -195,7 +196,7 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: def _setup_model( self, - model: str, + cfg_model: DictConfig, enable_fsdp: bool, enable_activation_checkpointing: bool, model_state_dict: Dict[str, Any], @@ -205,7 +206,9 @@ def _setup_model( ``enable_fsdp`` should always be ``True``. This is currently a configurable flag for running tests on CPUs. """ - model = models.get_model(model, device=self._device) + with self._device: + model = config.instantiate(cfg_model) + model = ( utils.wrap_fsdp( model=model, @@ -228,28 +231,14 @@ def _setup_model( log.info("Model is initialized.") return model - def _setup_tokenizer( - self, tokenizer: str, tokenizer_checkpoint: str - ) -> modules.Tokenizer: - """ - Unlike ```setup_model```, this takes in the checkpoint and loads the sentencepiece - tokenizer model. This is related to how the tokenizer is implemented and should - change in a future iteration. - """ - tokenizer = models.get_tokenizer(tokenizer, path=tokenizer_checkpoint) - - if self._is_rank_zero: - log.info("Tokenizer is initialized from file.") - return tokenizer - def _setup_optimizer( - self, optimizer: str, lr: float, opt_state_dict: Optional[Dict[str, Any]] = None + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None ) -> Optimizer: """ Set up the optimizer. This method also handles transforing the state dict for FSDP. """ - optimizer = modules.get_optimizer(optimizer, self._model, lr) + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) if opt_state_dict: opt_state_dict = utils.transform_opt_state_dict( opt_state_dict, self._model, optimizer @@ -260,16 +249,11 @@ def _setup_optimizer( log.info("Optimizer is initialized.") return optimizer - def _setup_loss(self, loss: str) -> nn.Module: - loss_fn = modules.get_loss(loss) - - if self._is_rank_zero: - log.info("Loss is initialized.") - - return loss_fn - def _setup_data( - self, dataset: str, shuffle: bool, batch_size: int, train_on_input: bool + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, ) -> Tuple[DistributedSampler, DataLoader]: """ All data related setup happens here. Currently this recipe only supports the @@ -277,11 +261,9 @@ def _setup_data( iterable datasets and streaming datasets are not supported. """ world_size, rank = utils.get_world_size_and_rank() - ds = datasets.get_dataset( - dataset, - split="train", + ds = config.instantiate( + cfg_dataset, tokenizer=self._tokenizer, - train_on_input=train_on_input, ) sampler = DistributedSampler( ds, @@ -422,28 +404,20 @@ def cleanup(self) -> None: self._metric_logger.close() -def recipe_main() -> None: +@config.parse +def recipe_main(cfg: DictConfig) -> None: """ Entry point for the recipe. Configurable parameters are read in the following order: - - Parameters specified in ``FullFinetuneParams`` - - Overwritten by Parameters specified in ``alpaca_llama2_full_finetune.yaml`` - - Overwritten by arguments from the command-line using ``TuneArgumentParser`` + - Parameters specified in ``alpaca_llama2_full_finetune.yaml`` + - Overwritten by arguments from the command-line using ``--override`` """ - parser = utils.TuneArgumentParser( - description=FullFinetuneParams.__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - args, _ = parser.parse_known_args() - args = vars(args) - recipe_params = FullFinetuneParams(**args) - # Env variables set by torch run; only need to initialize process group init_process_group(backend="nccl") - recipe = FullFinetuneRecipe(params=recipe_params) - recipe.setup(params=recipe_params) + recipe = FullFinetuneRecipe(cfg=cfg) + recipe.setup(cfg=cfg) recipe.train() recipe.cleanup() diff --git a/recipes/lora_finetune.py b/recipes/lora_finetune.py index c82b339d12..7e13097a77 100644 --- a/recipes/lora_finetune.py +++ b/recipes/lora_finetune.py @@ -4,25 +4,24 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import argparse import os import sys from functools import partial -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Optional, Tuple from warnings import warn import torch +from omegaconf import DictConfig from recipes.interfaces import FTRecipeInterface -from recipes.params.lora_finetune import LoRAFinetuneParams from torch import nn from torch.cuda.amp import GradScaler from torch.distributed import init_process_group from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import datasets, models, modules, utils +from torchtune import config, modules, utils from torchtune.modules.peft.lora import reset_lora_params from torchtune.modules.peft.peft_utils import ( get_adapter_params, @@ -63,12 +62,20 @@ class LoRAFinetuneRecipe(FTRecipeInterface): in ongoing epoch is lost. - Datasets are Map-style and data fits in memory (not streamed). + The following configs can be used to run this recipe: + >>> tune ls + RECIPE CONFIG + lora_finetune alpaca_llama2_lora_finetune + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + """ - def __init__(self, params: LoRAFinetuneParams) -> None: + def __init__(self, cfg: DictConfig) -> None: - self._device = utils.get_device(device=params.device) - self._dtype = utils.get_dtype(dtype=params.dtype) + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(dtype=cfg.dtype) # _is_rank_zero is used primarily for logging. In the future, the logger # should directly take care of this @@ -76,38 +83,33 @@ def __init__(self, params: LoRAFinetuneParams) -> None: self._is_rank_zero = rank == 0 # logging attributes - self._output_dir = params.output_dir - self._log_every_n_steps = ( - params.log_every_n_steps if params.log_every_n_steps else 1 - ) - if self._is_rank_zero: - self._metric_logger = utils.get_metric_logger( - metric_logger_type=params.metric_logger_type, - project=params.project, - log_dir=params.output_dir, - ) + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.log_every_n_steps if cfg.log_every_n_steps else 1 # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = utils.set_seed(seed=params.seed) + self.seed = utils.set_seed(seed=cfg.seed) self.epochs_run = 0 - self.total_epochs = params.epochs - self.max_steps_per_epoch = params.max_steps_per_epoch + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch self.total_training_steps = 0 - self._resume_from_checkpoint = params.resume_from_checkpoint + self._resume_from_checkpoint = cfg.resume_from_checkpoint - def setup(self, params: LoRAFinetuneParams) -> None: + def setup(self, cfg: DictConfig) -> None: """ Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + # Load in base model weights # Note that we set resume_from_checkpoint=False when loading the base model. # This is because we only save LoRA weights during training, so only lora_checkpoint # will contain training state, while model_checkpoint contains model weights only. base_model_ckpt = self.load_checkpoint( - ckpt_path=params.model_checkpoint, resume_from_checkpoint=False + ckpt_path=cfg.model_checkpoint, resume_from_checkpoint=False ) # If we're resuming from checkpoint, the recipe's state should be updated before @@ -115,53 +117,48 @@ def setup(self, params: LoRAFinetuneParams) -> None: # propagated to the relevant components if self._resume_from_checkpoint: assert ( - params.lora_checkpoint is not None + cfg.lora_checkpoint is not None ), "Must pass lora_checkpoint when resuming training" lora_ckpt = self.load_checkpoint( - ckpt_path=params.lora_checkpoint, resume_from_checkpoint=True + ckpt_path=cfg.lora_checkpoint, resume_from_checkpoint=True ) self._update_recipe_state(lora_ckpt) self._model = self._setup_model( - model=params.model, - lora_attn_modules=params.lora_attn_modules, - lora_rank=params.lora_rank, - lora_alpha=params.lora_alpha, - enable_fsdp=params.enable_fsdp, - enable_activation_checkpointing=params.enable_activation_checkpointing, + cfg_model=cfg.model, + enable_fsdp=cfg.enable_fsdp, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, base_model_state_dict=base_model_ckpt[MODEL_KEY], lora_weights_state_dict=lora_ckpt[MODEL_KEY] if self._resume_from_checkpoint else None, ) - self._tokenizer = self._setup_tokenizer( - tokenizer=params.tokenizer, tokenizer_checkpoint=params.tokenizer_checkpoint - ) + self._tokenizer = config.instantiate(cfg.tokenizer) + if self._is_rank_zero: + log.info("Tokenizer is initialized from file.") self._optimizer = self._setup_optimizer( - optimizer=params.optimizer, - lr=params.lr, - weight_decay=params.weight_decay, + cfg_optimizer=cfg.optimizer, opt_state_dict=lora_ckpt[OPT_KEY] if self._resume_from_checkpoint else None, ) - self._loss_fn = self._setup_loss(loss=params.loss) + self._loss_fn = config.instantiate(cfg.loss) + if self._is_rank_zero: + log.info("Loss is initialized.") # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after all of these are setup self._sampler, self._dataloader = self._setup_data( - dataset=params.dataset, - shuffle=params.shuffle, - batch_size=params.batch_size, - train_on_input=params.train_on_input, - use_clean=params.use_clean, + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, ) # training setup self._autocast = utils.get_autocast(self._dtype, self._device) if self._dtype == torch.float16: - self._grad_scaler = utils.get_gradient_scaler(fsdp=params.enable_fsdp) + self._grad_scaler = utils.get_gradient_scaler(fsdp=cfg.enable_fsdp) else: self._grad_scaler = GradScaler(enabled=False) @@ -182,8 +179,7 @@ def setup(self, params: LoRAFinetuneParams) -> None: # Learning rate scheduler can only be set up after number of steps # has been computed self._lr_scheduler = self._setup_lr_scheduler( - lr_scheduler=params.lr_scheduler, - num_warmup_steps=params.num_warmup_steps, + cfg_lr_scheduler=cfg.lr_scheduler, num_training_steps=self.total_epochs * steps_per_epoch, last_epoch=self.total_training_steps - 1, ) @@ -218,10 +214,7 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: def _setup_model( self, - model: str, - lora_attn_modules: List[str], - lora_rank: int, - lora_alpha: float, + cfg_model: DictConfig, enable_fsdp: bool, enable_activation_checkpointing: bool, base_model_state_dict: Dict[str, Any], @@ -229,14 +222,9 @@ def _setup_model( ) -> nn.Module: # LoRA recipe uses meta device for FSDP init to avoid peak memory reserved # during model init - init_device = "meta" if enable_fsdp else self._device - model = models.get_model( - model, - device=init_device, - lora_attn_modules=lora_attn_modules, - lora_rank=lora_rank, - lora_alpha=lora_alpha, - ) + init_device = torch.device("meta") if enable_fsdp else self._device + with init_device: + model = config.instantiate(cfg_model) reset_lora_params(model, device=self._device) @@ -265,7 +253,7 @@ def _setup_model( ) validate_state_dict_for_lora( - lora_modules=lora_attn_modules, + lora_modules=cfg_model.lora_attn_modules, full_model_state_dict_keys=model.state_dict().keys(), lora_state_dict_keys=lora_weights_state_dict.keys() if lora_weights_state_dict is not None @@ -280,31 +268,13 @@ def _setup_model( log.info("Model is initialized.") return model - def _setup_tokenizer( - self, tokenizer: str, tokenizer_checkpoint: str - ) -> modules.Tokenizer: - """ - Unlike ```setup_model```, this takes in the checkpoint and loads the sentencepiece - tokenizer model. This is related to how the tokenizer is implemented and should - change in a future iteration. - """ - tokenizer = models.get_tokenizer(tokenizer, path=tokenizer_checkpoint) - - if self._is_rank_zero: - log.info("Tokenizer is initialized from file.") - return tokenizer - def _setup_optimizer( - self, - optimizer: str, - lr: float, - weight_decay: float, - opt_state_dict: Optional[Dict[str, Any]] = None, + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None ) -> Optimizer: - optimizer = modules.get_optimizer(optimizer, self._model, lr, weight_decay) + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) if opt_state_dict: # Note: technically we should check _contains_fsdp for - # just the state dict of the adapter params, but should be equivalent + # just the state dict of the adapter cfg, but should be equivalent opt_state_dict = utils.transform_opt_state_dict( opt_state_dict, self._model, optimizer ) @@ -316,15 +286,13 @@ def _setup_optimizer( def _setup_lr_scheduler( self, - lr_scheduler: str, - num_warmup_steps: int, + cfg_lr_scheduler: DictConfig, num_training_steps: int, last_epoch: int, ) -> Optimizer: - lr_scheduler = modules.get_lr_scheduler( - lr_scheduler, + lr_scheduler = config.instantiate( + cfg_lr_scheduler, self._optimizer, - num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch, ) @@ -332,21 +300,11 @@ def _setup_lr_scheduler( log.info("Learning rate scheduler is initialized.") return lr_scheduler - def _setup_loss(self, loss: str) -> nn.Module: - loss_fn = modules.get_loss(loss) - - if self._is_rank_zero: - log.info("Loss is initialized.") - - return loss_fn - def _setup_data( self, - dataset: str, + cfg_dataset: DictConfig, shuffle: bool, batch_size: int, - train_on_input: bool, - use_clean: bool, ) -> Tuple[DistributedSampler, DataLoader]: """ All data related setup happens here. Currently this recipe only supports the @@ -354,11 +312,9 @@ def _setup_data( iterable datasets and streaming datasets are not supported. """ world_size, rank = utils.get_world_size_and_rank() - ds = datasets.get_dataset( - dataset, - split="train", + ds = config.instantiate( + cfg_dataset, tokenizer=self._tokenizer, - train_on_input=train_on_input, ) sampler = DistributedSampler( ds, @@ -476,28 +432,20 @@ def cleanup(self) -> None: self._metric_logger.close() -def recipe_main() -> None: +@config.parse +def recipe_main(cfg: DictConfig) -> None: """ Entry point for the recipe. Configurable parameters are read in the following order: - - Parameters specified in ``LoRAFinetuneParams`` - - Overwritten by Parameters specified in ``alpaca_llama2_lora_finetune.yaml`` - - Overwritten by arguments from the command-line using ``TuneArgumentParser`` + - Parameters specified in ``alpaca_llama2_lora_finetune.yaml`` + - Overwritten by arguments from the command-line using ``--override`` """ - parser = utils.TuneArgumentParser( - description=LoRAFinetuneParams.__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - args, _ = parser.parse_known_args() - args = vars(args) - recipe_params = LoRAFinetuneParams(**args) - # Env variables set by torch run; only need to initialize process group init_process_group(backend="nccl") - recipe = LoRAFinetuneRecipe(params=recipe_params) - recipe.setup(params=recipe_params) + recipe = LoRAFinetuneRecipe(cfg=cfg) + recipe.setup(cfg=cfg) recipe.train() recipe.cleanup() diff --git a/recipes/params/full_finetune.py b/recipes/params/full_finetune.py deleted file mode 100644 index 0e8ff1585f..0000000000 --- a/recipes/params/full_finetune.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass, fields -from typing import Optional - -from torchtune.datasets import ALL_DATASETS -from torchtune.models import ALL_MODELS, ALL_TOKENIZERS -from torchtune.utils.metric_logging import ALL_METRIC_LOGGERS -from torchtune.utils.precision import PRECISION_STR_TO_DTYPE - - -@dataclass -class FullFinetuneParams: - """Arguments for the finetune_llm recipe. - - Args: - device (str): Device to use for training. Options are "cpu" and "cuda" - dtype (str): Data type to use for training. - seed (int): Random seed to use for training. - model (str): String specifying model architecture to fine-tune. See ``torchtune.models.get_model`` for options. - model_checkpoint (str): Local path to load model checkpoint from. - tokenizer (str): String specifying tokenizer to use. See ``torchtune.models.get_tokenizer`` for options. - tokenizer_checkpoint (str): Local path to load tokenizer checkpoint from. - dataset (str): String specifying dataset to use. See ``torchtune.datasets.get_dataset`` for options. - Currently, only predefined datasets in library are supported. - shuffle (bool): Whether to shuffle dataset. - batch_size (int): Batch size to use for training. - epochs (int): Number of epochs to train for. - optimizer (str): String specifying optimizer to use. See ``torchtune.optim.get_optimizer`` for options. - loss (str): String specifying loss function to use. See ``torchtune.losses.get_loss`` for options. - lr (float): Learning rate to use for optimizer. - activation_checkpointing (bool): Whether to use activation checkpointing. - output_dir (str): Local path to save checkpoints and logs to. - run_generation (int): Run eval on a prompt every ``run_generation`` steps. Set to 0 to disable. - max_steps_per_epoch (int): Maximum number of steps to take per epoch. - metric_logger_type (str): String specifying metric logger to use. See ``torchtune.utils.get_metric_logger`` - for options. - project (str): Project name to use for logging. Used by ``WandBLogger``. - resume_from_previous_checkpoint (bool): Whether to resume fine-tuning from a previous checkpoint. - cpu_offload (bool): Whether to offload model to CPU. - - Raises: - ValueError: If ``cpu_offload`` is ``True`` but ``device`` is not ``cuda`` and <= 1 GPUs. - """ - - # Model - model: str = "" - model_checkpoint: str = "" - - # Tokenizer - tokenizer: str = "" - tokenizer_checkpoint: str = "" - - # Dataset and Sampler - dataset: str = "" - train_on_input: bool = True - shuffle: bool = True - batch_size: int = 2 - - # Optimizer and Scheduler - optimizer: str = "SGD" - lr: float = 2e-5 - loss: str = "CrossEntropyLoss" - gradient_accumulation_steps: int = 1 - - # Training - epochs: int = 3 - max_steps_per_epoch: Optional[int] = None - resume_from_checkpoint: bool = False - run_generation: Optional[int] = None - - # Distributed - cpu_offload: bool = False - enable_fsdp: bool = True - enable_activation_checkpointing: bool = True - - # Environment - device: str = "cuda" - dtype: str = "fp32" - seed: Optional[int] = None - - # Logging - output_dir: str = "/tmp/full_finetune_output" - metric_logger_type: str = "disk" - project: Optional[str] = None - log_every_n_steps: Optional[int] = None - - def __post_init__(self): - for param in fields(self): - if getattr(self, param.name) == "": - raise TypeError(f"{param.name} needs to be specified") - - if self.cpu_offload and self.device != "cuda": - raise ValueError( - "Cannot offload model to CPU if device is not cuda or <= 1 GPUs." - ) - if self.enable_fsdp and self.device == "cpu": - raise ValueError("FSDP is not supported on CPU.") - if self.model not in ALL_MODELS: - raise ValueError( - f"Model not recognized. Expected one of {ALL_MODELS}, received {self.model}." - ) - if self.tokenizer not in ALL_TOKENIZERS: - raise ValueError( - f"Tokenizer not recognized. Expected one of {ALL_TOKENIZERS}, received {self.tokenizer}." - ) - if self.dataset not in ALL_DATASETS: - raise ValueError( - f"Dataset not recognized. Expected one of {ALL_DATASETS}, received {self.dataset}." - ) - if self.metric_logger_type not in ALL_METRIC_LOGGERS: - raise ValueError( - f"Metric logger not recognized. Expected one of {ALL_METRIC_LOGGERS}, received {self.metric_logger_type}." - ) - if self.dtype not in PRECISION_STR_TO_DTYPE: - raise ValueError( - f"Dtype {self.dtype} must be one of {', '.join(PRECISION_STR_TO_DTYPE.keys())} for finetuning." - ) diff --git a/recipes/params/lora_finetune.py b/recipes/params/lora_finetune.py deleted file mode 100644 index e49f0273ae..0000000000 --- a/recipes/params/lora_finetune.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass, field, fields -from typing import List, Optional - -from torchtune.datasets import ALL_DATASETS -from torchtune.models import ALL_MODELS, ALL_TOKENIZERS -from torchtune.utils.metric_logging import ALL_METRIC_LOGGERS -from torchtune.utils.precision import PRECISION_STR_TO_DTYPE - - -@dataclass -class LoRAFinetuneParams: - """Arguments for the finetune_lora recipe. Note that LoRA is currently only supported - for attention modules (i.e. Q, K, V, output projections), and not for MLP layers. - - Args: - model (str): String specifying model architecture to fine-tune. See ``torchtune.models.get_model`` for options. - model_checkpoint (str): Local path to load model checkpoint from. - lora_attn_modules (List[str]): List of attention modules to use for LoRA. Supported values are - ["q_proj", "k_proj", "v_proj", "output_proj"]. - lora_rank (int): Rank of LoRA decompositions. - lora_alpha (float): Alpha parameter for LoRA. - lora_checkpoint (str): Local path to load LoRA weights from. - tokenizer (str): String specifying tokenizer to use. See ``torchtune.models.get_tokenizer`` for options. - tokenizer_checkpoint (str): Local path to load tokenizer checkpoint from. - dataset (str): String specifying dataset to use. See ``torchtune.datasets.get_dataset`` for options. - Currently, only predefined datasets in library are supported. - train_on_input (bool): Whether to train on the prompt in addition to the response. - use_clean (bool): Whether to use cleaned version of Alpaca dataset or not. - shuffle (bool): Whether to shuffle dataset. - batch_size (int): Batch size to use for training. - epochs (int): Number of epochs to train for. - optimizer (str): String specifying optimizer to use. See ``torchtune.optim.get_optimizer`` for options. - weight_decay (float): Weight decay to use for optimizer. - lr (float): Base learning rate rate to use for optimizer. - lr_scheduler (str): String specifying learning rate scheduler to use. See - ``torchtune.lr_schedulers.get_lr_scheduler`` for options. - num_warmup_steps (int): Number of warmup steps to use for learning rate scheduler. - loss (str): String specifying loss function to use. See ``torchtune.losses.get_loss`` for options. - epochs (int): Number of epochs to train for. - max_steps_per_epoch (int): Maximum number of steps to take per epoch. - resume_from_checkpoint (bool): Whether to resume fine-tuning from a previous checkpoint. - cpu_offload (bool): Whether to offload model to CPU. - enable_fsdp (bool): Whether to use FSDP. - enable_activation_checkpointing (bool): Whether to use activation checkpointing. - device (str): Device to use for training. Options are "cpu" and "cuda" - dtype (str): Data type to use for training. - seed (int): Random seed to use for training. - output_dir (str): Local path to save checkpoints and logs to. - metric_logger_type (str): String specifying metric logger to use. See ``torchtune.utils.get_metric_logger`` - for options. - project (str): Project name to use for logging. Used by ``WandBLogger``. - log_every_n_steps (int): How often to log metrics. - """ - - # Model - model: str = "" - model_checkpoint: str = "" - lora_attn_modules: List[str] = field(default_factory=list) - lora_rank: int = 8 - lora_alpha: float = 16 - lora_checkpoint: Optional[str] = None - - # Tokenizer - tokenizer: str = "" - tokenizer_checkpoint: str = "" - - # Dataset and Sampler - dataset: str = "" - train_on_input: bool = True - use_clean: bool = True - shuffle: bool = True - batch_size: int = 2 - - # Optimizer and Scheduler - optimizer: str = "AdamW" - weight_decay: float = 0.01 - lr: float = 3e-4 - lr_scheduler: str = "cosine_with_warmup" - num_warmup_steps: int = 100 - loss: str = "CrossEntropyLoss" - - # Training - epochs: int = 1 - max_steps_per_epoch: Optional[int] = None - resume_from_checkpoint: bool = False - - # Distributed - cpu_offload: bool = False - enable_fsdp: bool = True - enable_activation_checkpointing: bool = True - - # Environment - device: str = "cuda" - dtype: str = "fp32" - seed: Optional[int] = None - - # Logging - output_dir: str = "/tmp/lora_finetune_output" - metric_logger_type: str = "disk" - project: Optional[str] = None - log_every_n_steps: Optional[int] = None - - def __post_init__(self): - for param in fields(self): - if getattr(self, param.name) == "": - raise TypeError(f"{param.name} needs to be specified") - - if self.cpu_offload and self.device != "cuda": - raise ValueError( - "Cannot offload model to CPU if device is not cuda or <= 1 GPUs." - ) - if self.enable_fsdp and self.device == "cpu": - raise ValueError("FSDP is not supported on CPU.") - if self.model not in ALL_MODELS: - raise ValueError( - f"Model not recognized. Expected one of {ALL_MODELS}, received {self.model}." - ) - if self.tokenizer not in ALL_TOKENIZERS: - raise ValueError( - f"Tokenizer not recognized. Expected one of {ALL_TOKENIZERS}, received {self.tokenizer}." - ) - if self.dataset not in ALL_DATASETS: - raise ValueError( - f"Dataset not recognized. Expected one of {ALL_DATASETS}, received {self.dataset}." - ) - if self.metric_logger_type not in ALL_METRIC_LOGGERS: - raise ValueError( - f"Metric logger not recognized. Expected one of {ALL_METRIC_LOGGERS}, received {self.metric_logger_type}." - ) - if self.dtype not in PRECISION_STR_TO_DTYPE: - raise ValueError( - f"Dtype {self.dtype} must be one of {', '.join(PRECISION_STR_TO_DTYPE.keys())} for finetuning." - ) - if len(self.lora_attn_modules) == 0: - raise ValueError("Must specify at least one module to apply LoRA to") diff --git a/recipes/tests/configs/test_configs.py b/recipes/tests/configs/test_configs.py index 66393288ad..ab259f8b1b 100644 --- a/recipes/tests/configs/test_configs.py +++ b/recipes/tests/configs/test_configs.py @@ -6,16 +6,17 @@ import os import pytest +from omegaconf import OmegaConf +from recipes.full_finetune import FullFinetuneRecipe +from recipes.lora_finetune import LoRAFinetuneRecipe -from recipes.params.full_finetune import FullFinetuneParams -from recipes.params.lora_finetune import LoRAFinetuneParams from torchtune.utils.argparse import TuneArgumentParser ROOT_DIR: str = os.path.join(os.path.abspath(__file__), "../../../configs") -config_to_params = { - os.path.join(ROOT_DIR, "alpaca_llama2_full_finetune.yaml"): FullFinetuneParams, - os.path.join(ROOT_DIR, "alpaca_llama2_lora_finetune.yaml"): LoRAFinetuneParams, +config_to_recipe = { + os.path.join(ROOT_DIR, "alpaca_llama2_full_finetune.yaml"): FullFinetuneRecipe, + os.path.join(ROOT_DIR, "alpaca_llama2_lora_finetune.yaml"): LoRAFinetuneRecipe, } @@ -29,12 +30,17 @@ def parser(self): parser = TuneArgumentParser("Test parser") return parser + # TODO: update this test to run recipes with debug args, disabling for now + @pytest.mark.skip( + reason="Need to update to use debug args after config system is finalized." + ) def test_configs(self, parser) -> None: - for config_path, params in config_to_params.items(): + for config_path, recipe in config_to_recipe.items(): args, _ = parser.parse_known_args(["--config", config_path]) try: - _ = params(**vars(args)) + cfg = OmegaConf.create(vars(args)) + recipe(cfg) except ValueError as e: raise AssertionError( - f"Config {config_path} using params {params.__name__} is not well formed" + f"Config {config_path} for recipe {recipe.__name__} is not well formed" ) from e diff --git a/recipes/tests/test_alpaca_generate.py b/recipes/tests/test_alpaca_generate.py index 612371722a..08a7d5fdd8 100644 --- a/recipes/tests/test_alpaca_generate.py +++ b/recipes/tests/test_alpaca_generate.py @@ -8,6 +8,7 @@ from typing import Optional import recipes.alpaca_generate as alpaca_generate +from omegaconf import OmegaConf from torchtune import models from torchtune.models.llama2 import llama2 from torchtune.modules import TransformerDecoder @@ -26,7 +27,7 @@ def small_test_ckpt(max_batch_size: Optional[int] = None) -> TransformerDecoder: ) -models.ALL_MODELS["small_test_ckpt"] = small_test_ckpt +models.small_test_ckpt = small_test_ckpt logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -35,21 +36,24 @@ class TestAlpacaGenerateRecipe: def _fetch_ckpt_model_path(self, ckpt) -> str: if ckpt == "small_test_ckpt": return "/tmp/test-artifacts/small-ckpt-01242024" - if ckpt == "llama2_7b": + if ckpt == "llama2.llama2_7b": return "/tmp/test-artifacts/llama2-7b-01242024" raise ValueError(f"Unknown ckpt {ckpt}") def test_alpaca_generate(self, capsys, pytestconfig): large_scale = pytestconfig.getoption("--large-scale") - ckpt = "llama2_7b" if large_scale else "small_test_ckpt" + ckpt = "llama2.llama2_7b" if large_scale else "small_test_ckpt" kwargs_values = { - "model": ckpt, + "model": {"_component_": f"torchtune.models.{ckpt}"}, "model_checkpoint": self._fetch_ckpt_model_path(ckpt), - "tokenizer": "llama2_tokenizer", - "tokenizer_checkpoint": "/tmp/test-artifacts/tokenizer.model", + "tokenizer": { + "_component_": "torchtune.models.llama2.llama2_tokenizer", + "path": "/tmp/test-artifacts/tokenizer.model", + }, "instruction": "Answer the question.", "input": "What is some cool music from the 1920s?", "max_gen_len": 64, } - alpaca_generate.recipe(**kwargs_values) + cfg = OmegaConf.create(kwargs_values) + alpaca_generate.recipe(cfg) diff --git a/recipes/tests/test_full_finetune.py b/recipes/tests/test_full_finetune.py index e9b181dd52..7f4c849a85 100644 --- a/recipes/tests/test_full_finetune.py +++ b/recipes/tests/test_full_finetune.py @@ -13,8 +13,8 @@ import pytest import torch +from omegaconf import OmegaConf from recipes.full_finetune import FullFinetuneRecipe -from recipes.params.full_finetune import FullFinetuneParams from recipes.tests.utils import ( default_recipe_kwargs, fetch_ckpt_model_path, @@ -28,7 +28,7 @@ from torchtune.datasets._alpaca import CROSS_ENTROPY_IGNORE_IDX from torchtune.utils.collate import padded_collate -models.ALL_MODELS["small_test_ckpt"] = llama2_small_test_ckpt +models.small_test_ckpt = llama2_small_test_ckpt logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -51,21 +51,21 @@ def _fetch_expected_loss_values(self, ckpt) -> Dict[str, float]: } if ckpt == "small_test_ckpt": return small_test_ckpt_loss_values - if ckpt == "llama2_7b": + if ckpt == "llama2.llama2_7b": return llama2_7b_ckpt_loss_values raise ValueError(f"Unknown ckpt {ckpt}") def test_loss(self, capsys, pytestconfig): large_scale = pytestconfig.getoption("--large-scale") - ckpt = "llama2_7b" if large_scale else "small_test_ckpt" + ckpt = "llama2.llama2_7b" if large_scale else "small_test_ckpt" expected_loss_values = self._fetch_expected_loss_values(ckpt) kwargs_values = default_recipe_kwargs(ckpt) - recipe_params = FullFinetuneParams(**kwargs_values) + recipe_cfg = OmegaConf.create(kwargs_values) - recipe = FullFinetuneRecipe(recipe_params) - recipe.setup(params=recipe_params) + recipe = FullFinetuneRecipe(recipe_cfg) + recipe.setup(cfg=recipe_cfg) recipe.train() loss_values = fetch_loss_values(capsys.readouterr().err) @@ -84,26 +84,32 @@ def test_training_state_on_resume(self): expected_loss_values = self._fetch_expected_loss_values(model_ckpt) with tempfile.TemporaryDirectory() as tmpdirname: - kwargs_values = { - "dataset": "alpaca", - "seed": 9, - "shuffle": True, - "model": model_ckpt, - "model_checkpoint": fetch_ckpt_model_path(model_ckpt), - "tokenizer": "llama2_tokenizer", - "tokenizer_checkpoint": "/tmp/test-artifacts/tokenizer.model", - "epochs": 4, - "max_steps_per_epoch": 2, - "output_dir": tmpdirname, - "device": "cpu", - "resume_from_checkpoint": False, - "enable_fsdp": False, - } - - recipe_params = FullFinetuneParams(**kwargs_values) - - recipe = FullFinetuneRecipe(recipe_params) - recipe.setup(params=recipe_params) + kwargs_values = default_recipe_kwargs(model_ckpt) + kwargs_values.update( + { + "dataset": {"_component_": "torchtune.datasets.AlpacaDataset"}, + "seed": 9, + "shuffle": True, + "model": {"_component_": f"torchtune.models.{model_ckpt}"}, + "model_checkpoint": fetch_ckpt_model_path(model_ckpt), + "tokenizer": { + "_component_": "torchtune.models.llama2.llama2_tokenizer", + "path": "/tmp/test-artifacts/tokenizer.model", + }, + "epochs": 4, + "max_steps_per_epoch": 2, + "output_dir": tmpdirname, + "device": "cpu", + "resume_from_checkpoint": False, + "enable_fsdp": False, + "dtype": "fp32", + } + ) + + recipe_cfg = OmegaConf.create(kwargs_values) + + recipe = FullFinetuneRecipe(recipe_cfg) + recipe.setup(cfg=recipe_cfg) recipe.train() recipe.cleanup() @@ -111,24 +117,31 @@ def test_training_state_on_resume(self): # check if these are correctly inferred from the checkpoint # Note this will raise some warnings in the logs, but is a # stronger test - kwargs_values_resume = { - "dataset": "alpaca", - "shuffle": True, - "model": model_ckpt, - "model_checkpoint": os.path.join(tmpdirname, "model_2.ckpt"), - "tokenizer": "llama2_tokenizer", - "tokenizer_checkpoint": "/tmp/test-artifacts/tokenizer.model", - "epochs": 4, - "output_dir": tmpdirname, - "device": "cpu", - "resume_from_checkpoint": True, # set to True to resume - "enable_fsdp": False, - } - - recipe_params = FullFinetuneParams(**kwargs_values_resume) - - recipe = FullFinetuneRecipe(recipe_params) - recipe.setup(params=recipe_params) + kwargs_values_resume = deepcopy(kwargs_values) + kwargs_values_resume.update( + { + "dataset": {"_component_": "torchtune.datasets.AlpacaDataset"}, + "seed": None, + "max_steps_per_epoch": None, + "shuffle": True, + "model": {"_component_": f"torchtune.models.{model_ckpt}"}, + "model_checkpoint": os.path.join(tmpdirname, "model_2.ckpt"), + "tokenizer": { + "_component_": "torchtune.models.llama2.llama2_tokenizer", + "path": "/tmp/test-artifacts/tokenizer.model", + }, + "epochs": 4, + "output_dir": tmpdirname, + "device": "cpu", + "resume_from_checkpoint": True, # set to True to resume + "enable_fsdp": False, + } + ) + + recipe_cfg = OmegaConf.create(kwargs_values_resume) + + recipe = FullFinetuneRecipe(recipe_cfg) + recipe.setup(cfg=recipe_cfg) assert recipe.epochs_run == 3 assert recipe.seed == kwargs_values["seed"] @@ -169,12 +182,13 @@ def forward(self, x): def dummy_grad_accum_ckpt(): - model = DummyModel() - fixed_init_model(model) + with torch.device("cpu"): + model = DummyModel() + fixed_init_model(model) return model -models.ALL_MODELS["dummy_grad_accum_ckpt"] = dummy_grad_accum_ckpt +models.dummy_grad_accum_ckpt = dummy_grad_accum_ckpt @pytest.fixture @@ -205,42 +219,49 @@ def test_gradient_accumulation( model_ckpt = "dummy_grad_accum_ckpt" gradient_accumulation_steps = full_batch_size // micro_batch_size kwargs_values = { - "dataset": "alpaca", - "train_on_input": False, + "dataset": { + "_component_": "torchtune.datasets.AlpacaDataset", + "train_on_input": False, + }, "seed": 9, "shuffle": True, - "model": model_ckpt, + "model": {"_component_": f"torchtune.models.{model_ckpt}"}, "model_checkpoint": None, - "tokenizer": "llama2_tokenizer", - "tokenizer_checkpoint": "/tmp/test-artifacts/tokenizer.model", + "tokenizer": { + "_component_": "torchtune.models.llama2.llama2_tokenizer", + "path": "/tmp/test-artifacts/tokenizer.model", + }, "batch_size": full_batch_size, - "lr": 2e-5, "epochs": 1, # make sure to run for 1 epoch "max_steps_per_epoch": 1, - "optimizer": "AdamW", - "loss": "CrossEntropyLoss", + "optimizer": {"_component_": "torch.optim.AdamW", "lr": 2e-5}, + "loss": {"_component_": "torch.nn.CrossEntropyLoss"}, "output_dir": "/tmp", "device": "cpu", "dtype": "fp32", "resume_from_checkpoint": False, "enable_fsdp": False, "enable_activation_checkpointing": False, - "metric_logger_type": "disk", + "metric_logger": { + "_component_": "torchtune.utils.metric_logging.DiskLogger", + "log_dir": "${output_dir}", + }, "gradient_accumulation_steps": 1, + "log_every_n_steps": None, } # First run without gradient accumulation baseline_params = kwargs_values.copy() - baseline_recipe_params = FullFinetuneParams(**baseline_params) - baseline_recipe = FullFinetuneRecipe(baseline_recipe_params) + baseline_recipe_cfg = OmegaConf.create(baseline_params) + baseline_recipe = FullFinetuneRecipe(baseline_recipe_cfg) # Patch the recipe to use DummyModel class # Note that this cannot be done via a decorator because we use patch two separate times with mocker.patch( "recipes.full_finetune.FullFinetuneRecipe._setup_model", - return_value=models.get_model("dummy_grad_accum_ckpt", device="cpu"), + return_value=dummy_grad_accum_ckpt(), ): - baseline_recipe.setup(params=baseline_recipe_params) + baseline_recipe.setup(cfg=baseline_recipe_cfg) baseline_recipe.train() # the first run assumes the complete batch and so we have a single loss value @@ -255,16 +276,16 @@ def test_gradient_accumulation( grad_accum_params = kwargs_values.copy() grad_accum_params["batch_size"] = micro_batch_size grad_accum_params["gradient_accumulation_steps"] = gradient_accumulation_steps - grad_accum_recipe_params = FullFinetuneParams(**grad_accum_params) - grad_accum_recipe = FullFinetuneRecipe(grad_accum_recipe_params) + grad_accum_recipe_cfg = OmegaConf.create(grad_accum_params) + grad_accum_recipe = FullFinetuneRecipe(grad_accum_recipe_cfg) # Patch the recipe to use DummyModel class. We use a separate patch # because otherwise the model params would remain the same from the baseline with mocker.patch( "recipes.full_finetune.FullFinetuneRecipe._setup_model", - return_value=models.get_model("dummy_grad_accum_ckpt", device="cpu"), + return_value=dummy_grad_accum_ckpt(), ): - grad_accum_recipe.setup(params=grad_accum_recipe_params) + grad_accum_recipe.setup(cfg=grad_accum_recipe_cfg) # Copy the dataloader and run a few iterations. CrossEntropyLoss is normalized # by the number of unmasked tokens, so we need to derive these values per sample diff --git a/recipes/tests/test_lora_finetune.py b/recipes/tests/test_lora_finetune.py index 3a909f8ef8..24796872c0 100644 --- a/recipes/tests/test_lora_finetune.py +++ b/recipes/tests/test_lora_finetune.py @@ -9,8 +9,9 @@ from functools import partial from typing import Dict +from omegaconf import OmegaConf + from recipes.lora_finetune import LoRAFinetuneRecipe -from recipes.params.lora_finetune import LoRAFinetuneParams from recipes.tests.utils import ( default_recipe_kwargs, @@ -22,7 +23,7 @@ from torchtune import models test_lora_attn_modules = ["q_proj", "k_proj", "v_proj", "output_proj"] -models.ALL_MODELS["lora_small_test_ckpt"] = partial( +models.lora_small_test_ckpt = partial( lora_llama2_small_test_ckpt, lora_attn_modules=test_lora_attn_modules ) logging.basicConfig(level=logging.INFO) @@ -47,11 +48,17 @@ def test_loss(self, capsys, pytestconfig): ckpt = "lora_small_test_ckpt" expected_loss_values = self._fetch_expected_loss_values(ckpt) kwargs_values = default_recipe_kwargs(ckpt) - kwargs_values["lora_attn_modules"] = test_lora_attn_modules - recipe_params = LoRAFinetuneParams(**kwargs_values) + kwargs_values["model"].update( + { + "lora_attn_modules": test_lora_attn_modules, + "lora_rank": 8, + "lora_alpha": 16, + } + ) + recipe_cfg = OmegaConf.create(kwargs_values) - recipe = LoRAFinetuneRecipe(recipe_params) - recipe.setup(params=recipe_params) + recipe = LoRAFinetuneRecipe(recipe_cfg) + recipe.setup(cfg=recipe_cfg) recipe.train() loss_values = fetch_loss_values(capsys.readouterr().err) diff --git a/recipes/tests/test_params.py b/recipes/tests/test_params.py deleted file mode 100644 index 19d60185ac..0000000000 --- a/recipes/tests/test_params.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import pytest -from recipes.params.full_finetune import FullFinetuneParams - - -class TestParams: - @pytest.fixture - def params(self): - return dict( - dataset="alpaca", - seed=None, - shuffle=True, - model="llama2_7b", - model_checkpoint="/tmp/llama2-7b", - tokenizer="llama2_tokenizer", - tokenizer_checkpoint="/tmp/tokenizer.model", - batch_size=2, - lr=2e-5, - epochs=3, - optimizer="SGD", - loss="CrossEntropyLoss", - output_dir="/tmp/alpaca-llama2-finetune", - device="cuda", - dtype="fp32", - enable_activation_checkpointing=False, - enable_fsdp=False, - cpu_offload=False, - metric_logger_type="disk", - resume_from_checkpoint=False, - ) - - def test_bad_model(self, params): - with pytest.raises(ValueError): - params["model"] = "dummy" - _ = FullFinetuneParams(**params) - with pytest.raises(TypeError): - params["model"] = "" - _ = FullFinetuneParams(**params) - - def test_bad_dataset(self, params): - with pytest.raises(ValueError): - params["dataset"] = "dummy" - _ = FullFinetuneParams(**params) - - def test_bad_tokenizer(self, params): - with pytest.raises(ValueError): - params["tokenizer"] = "dummy" - _ = FullFinetuneParams(**params) - - def test_bad_dtype(self, params): - with pytest.raises(ValueError): - params["dtype"] = "dummy" - _ = FullFinetuneParams(**params) - - def test_bad_metric_logger(self, params): - with pytest.raises(ValueError): - params["metric_logger_type"] = "dummy" - _ = FullFinetuneParams(**params) - - def test_cpu_offload_without_cuda(self, params): - with pytest.raises(ValueError): - params["cpu_offload"] = True - params["device"] = "cpu" - _ = FullFinetuneParams(**params) - - def test_fsdp_not_on_cpu(self, params): - with pytest.raises(ValueError): - params["enable_fsdp"] = True - params["device"] = "cpu" - _ = FullFinetuneParams(**params) diff --git a/recipes/tests/utils.py b/recipes/tests/utils.py index 46c315aacc..db6625841a 100644 --- a/recipes/tests/utils.py +++ b/recipes/tests/utils.py @@ -80,25 +80,37 @@ def validate_loss_values(loss_values, expected_loss_values): def default_recipe_kwargs(ckpt): return { - "dataset": "alpaca", - "train_on_input": False, + "dataset": { + "_component_": "torchtune.datasets.AlpacaDataset", + "train_on_input": False, + }, "seed": 9, "shuffle": True, - "model": ckpt, + "model": {"_component_": f"torchtune.models.{ckpt}"}, "model_checkpoint": fetch_ckpt_model_path(ckpt), - "tokenizer": "llama2_tokenizer", - "tokenizer_checkpoint": "/tmp/test-artifacts/tokenizer.model", + "tokenizer": { + "_component_": "torchtune.models.llama2.llama2_tokenizer", + "path": "/tmp/test-artifacts/tokenizer.model", + }, "batch_size": 8, - "lr": 2e-5, "epochs": 2, "max_steps_per_epoch": 2, - "optimizer": "AdamW", - "loss": "CrossEntropyLoss", + "optimizer": {"_component_": "torch.optim.AdamW", "lr": 2e-5}, + "loss": {"_component_": "torch.nn.CrossEntropyLoss"}, "output_dir": "/tmp", "device": "cpu", "dtype": "fp32", "resume_from_checkpoint": False, "enable_fsdp": False, "enable_activation_checkpointing": False, - "metric_logger_type": "disk", + "metric_logger": { + "_component_": "torchtune.utils.metric_logging.DiskLogger", + "log_dir": "${output_dir}", + }, + "log_every_n_steps": None, + "gradient_accumulation_steps": 1, + "lr_scheduler": { + "_component_": "torchtune.modules.get_cosine_schedule_with_warmup", + "num_warmup_steps": 100, + }, } diff --git a/requirements.txt b/requirements.txt index 37b1663b1e..cb04a9c126 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ huggingface_hub==0.19.4 # Misc sentencepiece==0.1.99 tqdm==4.66.1 -omegaconf==2.3.0 +omegaconf>=2.3.0 # Evaluation lm_eval==0.4.1 diff --git a/tests/torchtune/config/test_instantiate.py b/tests/torchtune/config/test_instantiate.py new file mode 100644 index 0000000000..474ababd8c --- /dev/null +++ b/tests/torchtune/config/test_instantiate.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from omegaconf import OmegaConf +from torchtune.config._instantiate import ( + _create_component, + _has_component, + _instantiate_node, + instantiate, +) +from torchtune.config._utils import InstantiationError +from torchtune.modules import RMSNorm + + +class TestInstantiate: + @pytest.fixture + def config(self): + s = """ + a: b + b: c + test: + _component_: torchtune.modules.RMSNorm + dim: 5 + """ + return OmegaConf.create(s) + + @pytest.fixture + def module(self): + return RMSNorm(dim=5, eps=1e-4) + + def get_dim(self, rms_norm: RMSNorm): + return rms_norm.scale.shape[0] + + def test_has_path(self, config): + assert _has_component(config.test) + assert not _has_component(config.a) + + def test_call_object(self, module): + obj = RMSNorm + args = (5,) + kwargs = {"eps": 1e-4} + actual = _create_component(obj, args, kwargs) + expected = module + assert isinstance(actual, RMSNorm) + assert self.get_dim(actual) == self.get_dim(expected) + assert actual.eps == expected.eps + + def test_instantiate_node(self, config, module): + actual = _instantiate_node(config.test) + expected = module + assert isinstance(actual, RMSNorm) + assert self.get_dim(actual) == self.get_dim(expected) + + with pytest.raises( + InstantiationError, match="Cannot instantiate specified object" + ): + _ = _instantiate_node(config.a) + + def test_instantiate(self, config, module): + actual = instantiate(config.test) + expected = module + assert isinstance(actual, RMSNorm) + assert self.get_dim(actual) == self.get_dim(expected) + + # Test passing in kwargs + actual = instantiate(config.test, eps=1e-4) + assert actual.eps == expected.eps + + # Test passing in positional args + del config.test.dim + actual = instantiate(config.test, 3) + assert self.get_dim(actual) == 3 diff --git a/tests/torchtune/config/test_parse.py b/tests/torchtune/config/test_parse.py new file mode 100644 index 0000000000..df4b334a8c --- /dev/null +++ b/tests/torchtune/config/test_parse.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import Namespace +from unittest.mock import patch + +import pytest +from torchtune import config + +_CONFIG = Namespace(a=1, b=2) + + +class TestParse: + def test_parse(self): + a = 1 + b = 3 + + @config.parse + def func(cfg): + assert cfg.a == a + assert cfg.b != b + + with patch( + "torchtune.config._parse.TuneArgumentParser.parse_known_args", + return_value=(_CONFIG, None), + ) as mock_parse_args: + with pytest.raises(SystemExit): + func() + mock_parse_args.assert_called_once() diff --git a/tests/torchtune/config/test_utils.py b/tests/torchtune/config/test_utils.py new file mode 100644 index 0000000000..6a27bfdfd4 --- /dev/null +++ b/tests/torchtune/config/test_utils.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from torchtune.config._utils import _get_component_from_path, InstantiationError + + +class TestUtils: + def test_get_component_from_path(self): + good_paths = [ + "torchtune", # Test single module without dot + "torchtune.models", # Test dotpath for a module + "torchtune.models.llama2.llama2_7b", # Test dotpath for an object + ] + for path in good_paths: + _ = _get_component_from_path(path) + + # Test that a relative path fails + with pytest.raises(ValueError, match="Relative imports are not supported"): + _ = _get_component_from_path(".test") + # Test that a non-existent path fails + with pytest.raises( + InstantiationError, match="Error loading 'torchtune.models.dummy'" + ): + _ = _get_component_from_path("torchtune.models.dummy") diff --git a/tests/torchtune/datasets/test_alpaca_dataset.py b/tests/torchtune/datasets/test_alpaca_dataset.py index 552233e30d..4c7b9c1c04 100644 --- a/tests/torchtune/datasets/test_alpaca_dataset.py +++ b/tests/torchtune/datasets/test_alpaca_dataset.py @@ -10,8 +10,7 @@ from tests.test_utils import get_assets_path -from torchtune import datasets -from torchtune.datasets._alpaca import CROSS_ENTROPY_IGNORE_IDX +from torchtune.datasets._alpaca import AlpacaDataset, CROSS_ENTROPY_IGNORE_IDX from torchtune.modules.tokenizer import Tokenizer @@ -64,7 +63,7 @@ def test_prompt_generation(self, load_dataset, tokenizer): ), ] - alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer) + alpaca_dataset = AlpacaDataset(tokenizer=tokenizer) # alpaca_dataset._data contains the raw data loaded from HF's dataset. We need the raw data # to test the prompt generation since calling __getitem__ on the alpaca_dataset object will @@ -93,7 +92,7 @@ def test_label_no_masking(self, load_dataset, tokenizer): } ] - alpaca_dataset = datasets.get_dataset("alpaca", tokenizer=tokenizer) + alpaca_dataset = AlpacaDataset(tokenizer=tokenizer) input, labels = alpaca_dataset[0] assert len(input) == len(labels) @@ -120,9 +119,7 @@ def test_label_masking(self, load_dataset, tokenizer): } ] - alpaca_dataset = datasets.get_dataset( - "alpaca", tokenizer=tokenizer, train_on_input=False - ) + alpaca_dataset = AlpacaDataset(tokenizer=tokenizer, train_on_input=False) # Extract the prompt and tokenize it; we'll need this to test whether we're masking the # input correctly @@ -157,9 +154,7 @@ def test_alpaca_clean(self, load_dataset, tokenizer): } ] - alpaca_dataset = datasets.get_dataset( - "alpaca", tokenizer=tokenizer, use_clean=True - ) + alpaca_dataset = AlpacaDataset(tokenizer=tokenizer, use_clean=True) input, labels = alpaca_dataset[0] assert len(input) == len(labels) diff --git a/tests/torchtune/datasets/test_get_dataset.py b/tests/torchtune/datasets/test_get_dataset.py deleted file mode 100644 index 6a04c3dfed..0000000000 --- a/tests/torchtune/datasets/test_get_dataset.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from torchtune import datasets - - -class TestDatasetGetter: - def test_get_dataset(self): - """ - Test getting a named dataset - """ - datasets.ALL_DATASETS["test"] = lambda x: x - dataset = datasets.get_dataset("test", x=1) - assert dataset == 1 - - def test_list_datasets(self): - """ - Test accuracy of dataset list - """ - dataset_names = datasets.list_datasets() - assert "test" in dataset_names diff --git a/tests/torchtune/datasets/test_slimorca_dataset.py b/tests/torchtune/datasets/test_slimorca_dataset.py index c69aac2376..786dbf4f41 100644 --- a/tests/torchtune/datasets/test_slimorca_dataset.py +++ b/tests/torchtune/datasets/test_slimorca_dataset.py @@ -9,8 +9,7 @@ from tests.test_utils import get_assets_path -from torchtune import datasets -from torchtune.datasets._slimorca import _Llama2ChatFormatConstants +from torchtune.datasets._slimorca import _Llama2ChatFormatConstants, SlimOrcaDataset from torchtune.modules.tokenizer import Tokenizer @@ -24,7 +23,7 @@ def tokenizer(self): @patch("torchtune.datasets._slimorca.load_dataset") def test_prompt_label_generation(self, load_dataset, tokenizer): load_dataset.return_value = [] - dataset = datasets.get_dataset("slimorca", tokenizer=tokenizer) + dataset = SlimOrcaDataset(tokenizer=tokenizer) sample = [ { "from": "system", @@ -66,9 +65,7 @@ def test_prompt_label_generation(self, load_dataset, tokenizer): @patch("torchtune.datasets._slimorca.load_dataset") def test_token_generation(self, load_dataset, tokenizer): load_dataset.return_value = [] - dataset = datasets.get_dataset( - "slimorca", tokenizer=tokenizer, max_token_length=4096 - ) + dataset = SlimOrcaDataset(tokenizer=tokenizer, max_token_length=4096) input, label = dataset._generate_tokens("Hello ", "world!") assert input == [tokenizer.bos_id, 12, 1803, 1024, 103, tokenizer.eos_id] assert label == ([-100] * 3 + [1024, 103, tokenizer.eos_id]) @@ -76,9 +73,7 @@ def test_token_generation(self, load_dataset, tokenizer): @patch("torchtune.datasets._slimorca.load_dataset") def test_truncated_token_generation(self, load_dataset, tokenizer): load_dataset.return_value = [] - dataset = datasets.get_dataset( - "slimorca", tokenizer=tokenizer, max_token_length=5 - ) + dataset = SlimOrcaDataset(tokenizer=tokenizer, max_token_length=5) # 5 is enough for full prompt, but not for label input, label = dataset._generate_tokens("Hello ", "world!") assert input == [tokenizer.bos_id, 12, 1803, 1024, tokenizer.eos_id] @@ -86,9 +81,7 @@ def test_truncated_token_generation(self, load_dataset, tokenizer): # 4 is not enough for full prompt nor response but truncation # is still feasible - dataset = datasets.get_dataset( - "slimorca", tokenizer=tokenizer, max_token_length=4 - ) + dataset = SlimOrcaDataset(tokenizer=tokenizer, max_token_length=4) input, label = dataset._generate_tokens("Hello ", "world!") assert input == [tokenizer.bos_id, 12, 1024, tokenizer.eos_id] assert label == ([-100] * 2 + [1024, tokenizer.eos_id]) @@ -97,7 +90,7 @@ def test_truncated_token_generation(self, load_dataset, tokenizer): def test_value_error(self, load_dataset, tokenizer): load_dataset.return_value = [] with pytest.raises(ValueError): - datasets.get_dataset("slimorca", tokenizer=tokenizer, max_token_length=3) + SlimOrcaDataset(tokenizer=tokenizer, max_token_length=3) @patch("torchtune.datasets._slimorca.load_dataset") @pytest.mark.parametrize("max_token_length", [128, 512, 1024, 4096]) @@ -121,9 +114,7 @@ def test_dataset_get_item(self, load_dataset, tokenizer, max_token_length): ] } ] - ds = datasets.get_dataset( - "slimorca", tokenizer=tokenizer, max_token_length=max_token_length - ) + ds = SlimOrcaDataset(tokenizer=tokenizer, max_token_length=max_token_length) input, label = ds[0] assert len(input) <= max_token_length assert len(label) <= max_token_length diff --git a/tests/torchtune/models/test_get_model.py b/tests/torchtune/models/test_get_model.py deleted file mode 100644 index dd1c0504c5..0000000000 --- a/tests/torchtune/models/test_get_model.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from torchtune import models - - -class TestModelTokenizerGetter: - def test_get_model(self): - """ - Test getting a named model - """ - models.ALL_MODELS["test"] = lambda x: x - model = models.get_model("test", "cpu", x=1) - assert model == 1 - - def test_get_model_device(self): - models.ALL_MODELS["test"] = lambda x: x - model = models.get_model("test", device=torch.device("cpu"), x=1) - assert model == 1 - - def test_list_models(self): - """ - Test accuracy of model list - """ - model_names = models.list_models() - assert "test" in model_names - - def test_get_tokenizer(self): - """ - Test getting a named tokenizer - """ - models.ALL_TOKENIZERS["test"] = lambda x: x - tokenizer = models.get_tokenizer("test", x=1) - assert tokenizer == 1 - - def test_list_tokenizer(self): - """ - Test accuracy of tokenizer list - """ - tokenizer_names = models.list_tokenizers() - assert "test" in tokenizer_names diff --git a/tests/torchtune/utils/test_metric_logging.py b/tests/torchtune/utils/test_metric_logging.py index 667c396d04..40cad238f7 100644 --- a/tests/torchtune/utils/test_metric_logging.py +++ b/tests/torchtune/utils/test_metric_logging.py @@ -16,38 +16,12 @@ from torchtune.utils.metric_logging import ( DiskLogger, - get_metric_logger, - list_metric_loggers, StdoutLogger, TensorBoardLogger, WandBLogger, ) -class TestMetricLogger: - def test_list_metric_loggers(self) -> None: - assert set(list_metric_loggers()) == { - "disk", - "stdout", - "tensorboard", - "wandb", - } - - def test_get_metric_logger(self) -> None: - fake_kwargs = { - "log_dir": "/tmp/output", - "project": "test-project", - "extra_key": "bananas", - } - assert isinstance(get_metric_logger("disk", **fake_kwargs), DiskLogger) - assert isinstance(get_metric_logger("stdout", **fake_kwargs), StdoutLogger) - assert isinstance( - get_metric_logger("tensorboard", **fake_kwargs), TensorBoardLogger - ) - with patch("wandb.init") as wandb_init: - assert isinstance(get_metric_logger("wandb", **fake_kwargs), WandBLogger) - - class TestDiskLogger: def test_log(self) -> None: with tempfile.TemporaryDirectory() as log_dir: diff --git a/torchtune/config/__init__.py b/torchtune/config/__init__.py new file mode 100644 index 0000000000..932d2d30ca --- /dev/null +++ b/torchtune/config/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ._instantiate import instantiate +from ._parse import parse + +__all__ = [ + "parse", + "instantiate", +] diff --git a/torchtune/config/_instantiate.py b/torchtune/config/_instantiate.py new file mode 100644 index 0000000000..838bb1af0d --- /dev/null +++ b/torchtune/config/_instantiate.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +from typing import Any, Callable, Dict, Tuple + +from omegaconf import DictConfig, OmegaConf +from torchtune.config._utils import _get_component_from_path, InstantiationError + + +def _has_component(node: DictConfig) -> bool: + return OmegaConf.is_dict(node) and "_component_" in node + + +def _create_component( + _component_: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any] +): + return _component_(*args, **kwargs) + + +def _instantiate_node(node: DictConfig, *args: Tuple[Any, ...]): + """ + Creates the object specified in _component_ field with provided positional args + and kwargs already merged. Raises an InstantiationError if _component_ is not specified. + """ + if _has_component(node): + _component_ = _get_component_from_path(node.get("_component_")) + kwargs = {k: v for k, v in node.items() if k != "_component_"} + return _create_component(_component_, args, kwargs) + else: + raise InstantiationError( + "Cannot instantiate specified object." + + "\nMake sure you've specified a _component_ field with a valid dotpath." + ) + + +def instantiate( + config: DictConfig, + *args: Tuple[Any, ...], + **kwargs: Dict[str, Any], +) -> Any: + """ + Given a DictConfig with a _component_ field specifying the object to instantiate and + additional fields for keyword arguments, create an instance of the specified object. + You can use this function to create the exact instance of a TorchTune object you want + to use in your recipe using the specification from the config. + + This function also supports passing in positional args and keyword args within the + function call. These are automatically merged with the provided config, with keyword + args taking precedence. + + Examples: + >>> config.yaml: + >>> model: + >>> _component_: torchtune.models.llama2 + >>> num_layers: 32 + >>> num_heads: 32 + >>> num_kv_heads: 32 + + >>> from torchtune import config + >>> vocab_size = 32000 + >>> # Pass in vocab size as positional argument. Since it is positioned first + >>> # in llama2(), it must be specified first. Pass in other arguments as kwargs. + >>> # This will return an nn.Module directly for llama2 with specified args. + >>> model = config.instantiate(parsed_yaml.model, vocab_size, max_seq_len=4096, embed_dim=4096) + + Args: + config (DictConfig): a single field in the OmegaConf object parsed from the yaml file. + This is expected to have a _component_ field specifying the path of the object + to instantiate. + *args (Tuple[Any, ...]): positional arguments to pass to the object to instantiate. + **kwargs (Dict[str, Any]): keyword arguments to pass to the object to instantiate. + + Returns: + Any: the instantiated object. + + Raises: + ValueError: if config is not a DictConfig. + """ + + # Return None if config is None + if config is None: + return None + if not OmegaConf.is_dict(config): + raise ValueError(f"instantiate only supports DictConfigs, got {type(config)}") + + config_copy = copy.deepcopy(config) + config_copy._set_flag( + flags=["allow_objects", "struct", "readonly"], values=[True, False, False] + ) + config_copy._set_parent(config._get_parent()) + config = config_copy + + if kwargs: + # This overwrites any repeated fields in the config with kwargs + config = OmegaConf.merge(config, kwargs) + + # Resolve all interpolations, or references to other fields within the same config + OmegaConf.resolve(config) + + return _instantiate_node(config, *args) diff --git a/torchtune/config/_parse.py b/torchtune/config/_parse.py new file mode 100644 index 0000000000..943f002e30 --- /dev/null +++ b/torchtune/config/_parse.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import argparse +import functools +import sys +from typing import Any, Callable + +from omegaconf import DictConfig, OmegaConf +from torchtune.utils.argparse import TuneArgumentParser +from torchtune.utils.logging import get_logger + + +Recipe = Callable[[DictConfig], Any] + + +def parse(recipe_main: Recipe) -> Callable[[Recipe], Any]: + """ + Decorator that handles parsing the config file and CLI overrides + for a recipe. Use it on the recipe's main function. + + Example: in recipe/my_recipe.py, + >>> @parse + >>> def main(cfg: DictConfig): + >>> ... + + With the decorator, the parameters will be parsed into cfg when run as: + >>> tune my_recipe --config config.yaml --override foo=bar + + Args: + recipe_main (Recipe): The main method that initializes + and runs the recipe + + Returns: + Callable[[Recipe], Any]: the decorated main + """ + + @functools.wraps(recipe_main) + def wrapper(*args: Any, **kwargs: Any) -> Any: + parser = TuneArgumentParser( + description=recipe_main.__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + # Get user-specified args from config and CLI and create params for recipe + params, _ = parser.parse_known_args() + params = OmegaConf.create(vars(params)) + + logger = get_logger("DEBUG") + logger.info(msg=f"Running {recipe_main.__name__} with parameters {params}") + + sys.exit(recipe_main(params)) + + return wrapper diff --git a/torchtune/config/_utils.py b/torchtune/config/_utils.py new file mode 100644 index 0000000000..c8b19d853a --- /dev/null +++ b/torchtune/config/_utils.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from importlib import import_module +from types import ModuleType +from typing import Any + + +class InstantiationError(Exception): + pass + + +def _get_component_from_path(path: str) -> Any: + """ + Return an object by name or dotted path, importing as necessary. + The base functionality relies on ``getattr()`` and handles all + possible exceptions accordingly. + + Args: + path (str): Dotted path of the object + + Returns: + Any: The object + + Raises: + ImportError: If the path is empty or there is an exception loading the + object from the provided path + ValueError: If a relative or invalid dotpath is passed in + """ + if path == "": + raise ValueError("Empty path") + + parts = [part for part in path.split(".")] + for part in parts: + # If a relative path is passed in, the first part will be empty + if not len(part): + raise ValueError( + f"Error loading '{path}': invalid dotstring." + + "\nRelative imports are not supported." + ) + # First module requires trying to import to validate + part0 = parts[0] + try: + obj = import_module(part0) + except ImportError as exc_import: + raise InstantiationError( + f"Error loading '{path}':\n{repr(exc_import)}" + + f"\nAre you sure that module '{part0}' is installed?" + ) from exc_import + # Subsequent components can be checked via getattr() on first module + # It can either be an attribute that we can return or a submodule that we + # can import and continue searching + for m in range(1, len(parts)): + part = parts[m] + try: + obj = getattr(obj, part) + # If getattr fails, check to see if it's a module we can import and + # continue down the path + except AttributeError as exc_attr: + parent_dotpath = ".".join(parts[:m]) + if isinstance(obj, ModuleType): + mod = ".".join(parts[: m + 1]) + try: + obj = import_module(mod) + continue + except ModuleNotFoundError as exc_import: + raise InstantiationError( + f"Error loading '{path}':\n{repr(exc_import)}" + + f"\nAre you sure that '{part}' is importable from module '{parent_dotpath}'?" + ) from exc_import + # Any other error trying to import module can be raised as + # InstantiationError + except Exception as exc_import: + raise InstantiationError( + f"Error loading '{path}':\n{repr(exc_import)}" + ) from exc_import + # If the component is not an attribute nor a module, it doesn't exist + raise InstantiationError( + f"Error loading '{path}':\n{repr(exc_attr)}" + + f"\nAre you sure that '{part}' is an attribute of '{parent_dotpath}'?" + ) from exc_attr + return obj diff --git a/torchtune/datasets/__init__.py b/torchtune/datasets/__init__.py index 6012f49c9c..4e40b65ed2 100644 --- a/torchtune/datasets/__init__.py +++ b/torchtune/datasets/__init__.py @@ -4,24 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torch.utils.data import Dataset - from ._alpaca import AlpacaDataset from ._slimorca import SlimOrcaDataset -ALL_DATASETS = {"alpaca": AlpacaDataset, "slimorca": SlimOrcaDataset} - - -def get_dataset(name: str, **kwargs) -> Dataset: - """Get known supported datasets by name""" - if name in ALL_DATASETS: - return ALL_DATASETS[name](**kwargs) - else: - raise ValueError( - f"Dataset not recognized. Expected one of {ALL_DATASETS}, received {name}" - ) - - -def list_datasets(): - """List of availabe datasets supported by `get_dataset`""" - return list(ALL_DATASETS) +__all__ = [ + "AlpacaDataset", + "SlimOrcaDataset", +] diff --git a/torchtune/losses.py b/torchtune/losses.py deleted file mode 100644 index b4290b095f..0000000000 --- a/torchtune/losses.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from torch import nn - - -def get_loss(loss: str) -> nn.Module: - """Returns a loss function from torch.nn. - - Args: - loss (str): name of the loss function. - - Returns: - nn.Module: loss function. - - Raises: - ValueError: if the loss is not a valid loss from torch.nn. - """ - try: - return getattr(nn, loss)() - except AttributeError as e: - raise ValueError(f"{loss} is not a valid loss from torch.nn") from e - - -# TODO convert to folder when we support llm specific losses diff --git a/torchtune/models/__init__.py b/torchtune/models/__init__.py index 3035d5b0b9..a4fcc101f4 100644 --- a/torchtune/models/__init__.py +++ b/torchtune/models/__init__.py @@ -4,46 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable, Union - -import torch -from torch.nn import Module - from torchtune.models import llama2 -from torchtune.utils import get_device - -ALL_MODELS = {"llama2_7b": llama2.llama2_7b, "lora_llama2_7b": llama2.lora_llama2_7b} -ALL_TOKENIZERS = {"llama2_tokenizer": llama2.llama2_tokenizer} - - -def get_model(name: str, device: Union[str, torch.device], **kwargs) -> Module: - """Get known supported models by name""" - if name in ALL_MODELS: - with get_device(device): - model = ALL_MODELS[name](**kwargs) - return model - else: - raise ValueError( - f"Model not recognized. Expected one of {ALL_MODELS}, received {name}" - ) - - -def get_tokenizer(name: str, **kwargs) -> Callable: - """Get known supported tokenizers by name""" - if name in ALL_TOKENIZERS: - return ALL_TOKENIZERS[name](**kwargs) - else: - raise ValueError( - f"Tokenizer not recognized. Expected one of {ALL_TOKENIZERS}, received {name}" - ) - - -def list_models(): - """List of availabe models supported by `get_model`""" - return list(ALL_MODELS) - - -def list_tokenizers(): - """List of availabe tokenizers supported by `get_tokenizer`""" - return list(ALL_TOKENIZERS) +__all__ = [ + "llama2", +] diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index ac011ccd3a..2603a3a060 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -4,12 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import torch - -from torch import nn -from torch.optim.lr_scheduler import LRScheduler -from torch.optim.optimizer import Optimizer - from .attention import CausalSelfAttention # noqa from .feed_forward import FeedForward # noqa from .kv_cache import KVCache # noqa @@ -30,79 +24,3 @@ "TransformerDecoder", "TransformerDecoderLayer", ] - - -def get_loss(loss: str) -> nn.Module: - """Returns a loss function from torch.nn. - - Args: - loss (str): name of the loss function. - - Returns: - nn.Module: loss function. - - Raises: - ValueError: if the loss is not a valid loss from torch.nn. - """ - try: - return getattr(nn, loss)() - except AttributeError as e: - raise ValueError(f"{loss} is not a valid loss from torch.nn") from e - - -def get_optimizer( - optimizer: str, model: torch.nn.Module, lr: float, weight_decay: float = 0.0 -) -> Optimizer: - """Returns an optimizer function from torch.optim. - - Args: - optimizer (str): name of the optimizer. - model (torch.nn.Module): model to optimize. - lr (float): learning rate. - weight_decay (float): weight decay for optimizer. Default is 0.0. - - Returns: - Optimizer: optimizer function. - - Raises: - ValueError: if the optimizer is not a valid optimizer from torch.optim. - """ - try: - trainable_params = [p for n, p in model.named_parameters() if p.requires_grad] - return getattr(torch.optim, optimizer)( - trainable_params, lr=lr, weight_decay=weight_decay - ) - except AttributeError as e: - raise ValueError( - f"{optimizer} is not a valid optimizer from torch.optim" - ) from e - - -ALL_LR_SCHEDULERS = {"cosine_with_warmup": get_cosine_schedule_with_warmup} - - -def get_lr_scheduler( - lr_scheduler: str, optimizer: torch.optim.Optimizer, **kwargs -) -> LRScheduler: - """Returns an optimizer function from torch.optim. - - Args: - lr_scheduler (str): name of the learning rate scheduler. - optimizer (torch.optim.Optimizer): optimizer. - **kwargs: additional arguments to pass to the learning rate scheduler. - - Returns: - LRScheduler: learning rate scheduler. - - Raises: - ValueError: if the lr scheduler is not a valid optimizer from torch.optim. - """ - try: - if lr_scheduler in ALL_LR_SCHEDULERS: - return ALL_LR_SCHEDULERS[lr_scheduler](optimizer, **kwargs) - else: - getattr(torch.optim.lr_scheduler, lr_scheduler)(optimizer, **kwargs) - except AttributeError as e: - raise ValueError( - f"{lr_scheduler} is not a valid learning rate scheduler from torch.optim.lr_scheduler or torchtune" - ) from e diff --git a/torchtune/optim.py b/torchtune/optim.py deleted file mode 100644 index 65f89eafc2..0000000000 --- a/torchtune/optim.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -import torch -from torch.optim.optimizer import Optimizer - - -def get_optimizer( - optimizer: str, model: torch.nn.Module, lr: float, weight_decay: float = 0.0 -) -> Optimizer: - """Returns an optimizer function from torch.optim. - - Args: - optimizer (str): name of the optimizer. - model (torch.nn.Module): model to optimize. - lr (float): learning rate. - weight_decay (float): weight decay for optimizer. Default is 0.0. - - Returns: - Optimizer: optimizer function. - - Raises: - ValueError: if the optimizer is not a valid optimizer from torch.optim. - """ - try: - trainable_params = [p for n, p in model.named_parameters() if p.requires_grad] - return getattr(torch.optim, optimizer)( - trainable_params, lr=lr, weight_decay=weight_decay - ) - except AttributeError as e: - raise ValueError( - f"{optimizer} is not a valid optimizer from torch.optim" - ) from e - - -# TODO convert to folder when we support tuning specific optimizers diff --git a/torchtune/utils/__init__.py b/torchtune/utils/__init__.py index 583e3039a6..a7ceaf5288 100644 --- a/torchtune/utils/__init__.py +++ b/torchtune/utils/__init__.py @@ -11,16 +11,13 @@ from .distributed import get_world_size_and_rank, init_distributed, wrap_fsdp from .logging import get_logger from .memory import set_activation_checkpointing -from .metric_logging import get_metric_logger, list_metric_loggers from .precision import get_autocast, get_dtype, get_gradient_scaler, list_dtypes from .seed import set_seed __all__ = [ - "list_metric_loggers", "save_checkpoint", "transform_opt_state_dict", "validate_checkpoint", - "get_metric_logger", "get_autocast", "get_device", "get_dtype", diff --git a/torchtune/utils/metric_logging.py b/torchtune/utils/metric_logging.py index f0a25a568b..9c6999ba08 100644 --- a/torchtune/utils/metric_logging.py +++ b/torchtune/utils/metric_logging.py @@ -8,7 +8,7 @@ import time from pathlib import Path -from typing import Dict, List, Mapping, Optional, Union +from typing import Mapping, Optional, Union from numpy import ndarray from torch import Tensor @@ -239,41 +239,3 @@ def close(self) -> None: if self._writer: self._writer.close() self._writer = None - - -ALL_METRIC_LOGGERS: Dict[str, "MetricLoggerInterface"] = { - "wandb": WandBLogger, - "tensorboard": TensorBoardLogger, - "stdout": StdoutLogger, - "disk": DiskLogger, -} - - -def list_metric_loggers() -> List[str]: - """List available metric loggers. - - Returns: - List[str]: list of available metric loggers - """ - return list(ALL_METRIC_LOGGERS.keys()) - - -def get_metric_logger(metric_logger_type: str, **kwargs) -> "MetricLoggerInterface": - """Get a metric logger based on provided arguments. - - Args: - metric_logger_type (str): name of the metric logger, options are "wandb", "tensorboard", "stdout", "disk". - **kwargs: additional arguments to pass to the metric logger - - Raises: - ValueError: If ``metric_logger`` str is unknown. - - Returns: - MetricLoggerInterface: metric logger - """ - if metric_logger_type not in ALL_METRIC_LOGGERS: - raise ValueError( - f"Metric logger not recognized. Expected one of {list_metric_loggers}, received {metric_logger_type}." - ) - - return ALL_METRIC_LOGGERS[metric_logger_type](**kwargs)