Skip to content

Commit

Permalink
Fixed interpolation in instantiate config/params
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Sep 22, 2020
1 parent 94bbaaa commit 65bbb03
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 20 deletions.
9 changes: 6 additions & 3 deletions hydra/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,9 @@ def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:
config_overrides = {}
passthrough = {}
for k, v in kwargs.items():
if k in params and get_ref_type(params, k) is not Any:
if k in params and not (
get_ref_type(params, k) is Any and OmegaConf.is_missing(params, k)
):
config_overrides[k] = v
else:
passthrough[k] = v
Expand All @@ -595,10 +597,11 @@ def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any:
with read_write(params):
params.merge_with(config_overrides)

for k, v in params.items_ex(resolve=False):
for k in params.keys():
if k == "_target_":
continue
final_kwargs[k] = v
if k not in passthrough:
final_kwargs[k] = params[k]

for k, v in passthrough.items():
final_kwargs[k] = v
Expand Down
1 change: 1 addition & 0 deletions news/1001.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed interaction between interpolation and instantiate
18 changes: 6 additions & 12 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@ def module_function(x: int) -> int:

@dataclass
class AClass:
def __init__(self, a: Any, b: Any, c: Any, d: Any = "default_value") -> None:
self.a = a
self.b = b
self.c = c
self.d = d
a: Any
b: Any
c: Any
d: Any = "default_value"

@staticmethod
def static_method(z: int) -> int:
Expand All @@ -43,16 +42,11 @@ def static_method(z: int) -> int:
class UntypedPassthroughConf:
_target_: str = "tests.UntypedPassthroughClass"
a: Any = MISSING
b: Any = 10


@dataclass
class UntypedPassthroughClass:
def __init__(self, a: Any, b: Any) -> None:
self.a = a
self.b = b

def __eq__(self, other: Any) -> Any:
return self.a == other.a and self.b == other.b
a: Any


# Type not legal in a config
Expand Down
19 changes: 14 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ def test_get_class(path: str, expected_type: type) -> None:
assert utils.get_class(path) == expected_type


def test_a_class_eq() -> None:
assert AClass(a=10, b=20, c=30, d=40) != AClass(a=11, b=12, c=13, d=14)
assert AClass(a=10, b=20, c=30, d=40) == AClass(a=10, b=20, c=30, d=40)


@pytest.mark.parametrize( # type: ignore
"input_conf, passthrough, expected",
[
Expand Down Expand Up @@ -143,8 +148,8 @@ def test_get_class(path: str, expected_type: type) -> None:
),
pytest.param(
UntypedPassthroughConf,
{"a": IllegalType(), "b": IllegalType()},
UntypedPassthroughClass(a=IllegalType(), b=IllegalType()),
{"a": IllegalType()},
UntypedPassthroughClass(a=IllegalType()),
id="untyped_passthrough",
),
],
Expand Down Expand Up @@ -232,10 +237,14 @@ def test_class_instantiate_objectconf_pass_omegaconf_node() -> Any:

def test_class_instantiate_omegaconf_node() -> Any:
conf = OmegaConf.structured(
{"_target_": "tests.AClass", "b": 200, "c": {"x": 10, "y": "${params.b}"}}
{
"_target_": "tests.AClass",
"b": 200,
"c": {"x": 10, "y": "${b}"},
}
)
obj = utils.instantiate(conf, **{"a": 10, "d": AnotherClass(99)})
assert obj == AClass(10, 200, {"x": 10, "y": 200}, AnotherClass(99))
obj = utils.instantiate(conf, a=10, d=AnotherClass(99))
assert obj == AClass(a=10, b=200, c={"x": 10, "y": 200}, d=AnotherClass(99))
assert OmegaConf.is_config(obj.c)


Expand Down

0 comments on commit 65bbb03

Please sign in to comment.