diff --git a/docs/source/examples/configs.rst b/docs/source/examples/configs.rst index 4217356181..8a6caa2bb4 100644 --- a/docs/source/examples/configs.rst +++ b/docs/source/examples/configs.rst @@ -219,9 +219,9 @@ For example, to run the :code:`full_finetune` recipe with custom model and token Overriding components ^^^^^^^^^^^^^^^^^^^^^ -If you would like to override a parameter in the config that has a :code:`_component_` -field, you can do so by assigning to the parameter name directly. Any nested fields -in the components can be overridden with dot notation. +If you would like to override a class or function in the config that is instantiated +via the :code:`_component_` field, you can do so by assigning to the parameter +name directly. Any nested fields in the components can be overridden with dot notation. .. code-block:: yaml diff --git a/recipes/alpaca_generate.py b/recipes/alpaca_generate.py index c6c85a9586..eb2cd8c28d 100644 --- a/recipes/alpaca_generate.py +++ b/recipes/alpaca_generate.py @@ -4,8 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import sys - import torch from omegaconf import DictConfig diff --git a/tests/torchtune/config/test_utils.py b/tests/torchtune/config/test_utils.py index 60669e8f53..2ab0b1fe9e 100644 --- a/tests/torchtune/config/test_utils.py +++ b/tests/torchtune/config/test_utils.py @@ -67,3 +67,31 @@ def test_merge_yaml_and_cli_args(self, mock_load): assert conf.d == 6, f"d == {conf.d}, not 6 as set in overrides." assert conf.e == 7, f"e == {conf.e}, not 7 as set in overrides." mock_load.assert_called_once() + + yaml_args, cli_args = parser.parse_known_args( + [ + "--config", + "test.yaml", + "b=5", # Test overriding component path but keeping other kwargs + ] + ) + conf = _merge_yaml_and_cli_args(yaml_args, cli_args) + assert ( + conf.b._component_ == 5 + ), f"b == {conf.b._component_}, not 5 as set in overrides." + assert conf.b.c == 3, f"b.c == {conf.b.c}, not 3 as set in the config." + assert mock_load.call_count == 2 + + yaml_args, cli_args = parser.parse_known_args( + [ + "--config", + "test.yaml", + "b.c=5", # Test overriding kwarg but keeping component path + ] + ) + conf = _merge_yaml_and_cli_args(yaml_args, cli_args) + assert ( + conf.b._component_ == 2 + ), f"b == {conf.b._component_}, not 2 as set in the config." + assert conf.b.c == 5, f"b.c == {conf.b.c}, not 5 as set in overrides." + assert mock_load.call_count == 3 diff --git a/torchtune/config/_utils.py b/torchtune/config/_utils.py index 05d59fe0b4..c071c7d194 100644 --- a/torchtune/config/_utils.py +++ b/torchtune/config/_utils.py @@ -7,7 +7,7 @@ from argparse import Namespace from importlib import import_module from types import ModuleType -from typing import Any, List +from typing import Any, Dict, List, Union from omegaconf import DictConfig, OmegaConf