Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
RdoubleA committed Mar 7, 2024
1 parent 8e03425 commit 4ca0a48
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 6 deletions.
6 changes: 3 additions & 3 deletions docs/source/examples/configs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,9 @@ For example, to run the :code:`full_finetune` recipe with custom model and token
Overriding components
^^^^^^^^^^^^^^^^^^^^^
If you would like to override a parameter in the config that has a :code:`_component_`
field, you can do so by assigning to the parameter name directly. Any nested fields
in the components can be overridden with dot notation.
If you would like to override a class or function in the config that is instantiated
via the :code:`_component_` field, you can do so by assigning to the parameter
name directly. Any nested fields in the components can be overridden with dot notation.
.. code-block:: yaml
Expand Down
2 changes: 0 additions & 2 deletions recipes/alpaca_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

import torch
from omegaconf import DictConfig

Expand Down
28 changes: 28 additions & 0 deletions tests/torchtune/config/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,31 @@ def test_merge_yaml_and_cli_args(self, mock_load):
assert conf.d == 6, f"d == {conf.d}, not 6 as set in overrides."
assert conf.e == 7, f"e == {conf.e}, not 7 as set in overrides."
mock_load.assert_called_once()

yaml_args, cli_args = parser.parse_known_args(
[
"--config",
"test.yaml",
"b=5", # Test overriding component path but keeping other kwargs
]
)
conf = _merge_yaml_and_cli_args(yaml_args, cli_args)
assert (
conf.b._component_ == 5
), f"b == {conf.b._component_}, not 5 as set in overrides."
assert conf.b.c == 3, f"b.c == {conf.b.c}, not 3 as set in the config."
assert mock_load.call_count == 2

yaml_args, cli_args = parser.parse_known_args(
[
"--config",
"test.yaml",
"b.c=5", # Test overriding kwarg but keeping component path
]
)
conf = _merge_yaml_and_cli_args(yaml_args, cli_args)
assert (
conf.b._component_ == 2
), f"b == {conf.b._component_}, not 2 as set in the config."
assert conf.b.c == 5, f"b.c == {conf.b.c}, not 5 as set in overrides."
assert mock_load.call_count == 3
2 changes: 1 addition & 1 deletion torchtune/config/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from argparse import Namespace
from importlib import import_module
from types import ModuleType
from typing import Any, List
from typing import Any, Dict, List, Union

from omegaconf import DictConfig, OmegaConf

Expand Down

0 comments on commit 4ca0a48

Please sign in to comment.