diff --git a/hydra/_internal/utils.py b/hydra/_internal/utils.py index e133be6bfd6..1b850aeeee3 100644 --- a/hydra/_internal/utils.py +++ b/hydra/_internal/utils.py @@ -586,7 +586,9 @@ def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any: config_overrides = {} passthrough = {} for k, v in kwargs.items(): - if k in params and get_ref_type(params, k) is not Any: + if k in params and not ( + get_ref_type(params, k) is Any and OmegaConf.is_missing(params, k) + ): config_overrides[k] = v else: passthrough[k] = v @@ -595,10 +597,11 @@ def _get_kwargs(config: Union[ObjectConf, DictConfig], **kwargs: Any) -> Any: with read_write(params): params.merge_with(config_overrides) - for k, v in params.items_ex(resolve=False): + for k in params.keys(): if k == "_target_": continue - final_kwargs[k] = v + if k not in passthrough: + final_kwargs[k] = params[k] for k, v in passthrough.items(): final_kwargs[k] = v diff --git a/news/1001.bugfix b/news/1001.bugfix new file mode 100644 index 00000000000..02affd64bca --- /dev/null +++ b/news/1001.bugfix @@ -0,0 +1 @@ +Fixed interaction between interpolation and instantiate diff --git a/tests/__init__.py b/tests/__init__.py index 8e6a47929ae..4fa533e6d75 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -28,11 +28,10 @@ def module_function(x: int) -> int: @dataclass class AClass: - def __init__(self, a: Any, b: Any, c: Any, d: Any = "default_value") -> None: - self.a = a - self.b = b - self.c = c - self.d = d + a: Any + b: Any + c: Any + d: Any = "default_value" @staticmethod def static_method(z: int) -> int: @@ -43,16 +42,11 @@ def static_method(z: int) -> int: class UntypedPassthroughConf: _target_: str = "tests.UntypedPassthroughClass" a: Any = MISSING - b: Any = 10 +@dataclass class UntypedPassthroughClass: - def __init__(self, a: Any, b: Any) -> None: - self.a = a - self.b = b - - def __eq__(self, other: Any) -> Any: - return self.a == other.a and self.b == other.b + a: Any # Type not legal in a config diff --git a/tests/test_utils.py b/tests/test_utils.py index 1d74d41afbf..485e779fd83 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -52,6 +52,11 @@ def test_get_class(path: str, expected_type: type) -> None: assert utils.get_class(path) == expected_type +def test_a_class_eq() -> None: + assert AClass(a=10, b=20, c=30, d=40) != AClass(a=11, b=12, c=13, d=14) + assert AClass(a=10, b=20, c=30, d=40) == AClass(a=10, b=20, c=30, d=40) + + @pytest.mark.parametrize( # type: ignore "input_conf, passthrough, expected", [ @@ -143,8 +148,8 @@ def test_get_class(path: str, expected_type: type) -> None: ), pytest.param( UntypedPassthroughConf, - {"a": IllegalType(), "b": IllegalType()}, - UntypedPassthroughClass(a=IllegalType(), b=IllegalType()), + {"a": IllegalType()}, + UntypedPassthroughClass(a=IllegalType()), id="untyped_passthrough", ), ], @@ -232,10 +237,14 @@ def test_class_instantiate_objectconf_pass_omegaconf_node() -> Any: def test_class_instantiate_omegaconf_node() -> Any: conf = OmegaConf.structured( - {"_target_": "tests.AClass", "b": 200, "c": {"x": 10, "y": "${params.b}"}} + { + "_target_": "tests.AClass", + "b": 200, + "c": {"x": 10, "y": "${b}"}, + } ) - obj = utils.instantiate(conf, **{"a": 10, "d": AnotherClass(99)}) - assert obj == AClass(10, 200, {"x": 10, "y": 200}, AnotherClass(99)) + obj = utils.instantiate(conf, a=10, d=AnotherClass(99)) + assert obj == AClass(a=10, b=200, c={"x": 10, "y": 200}, d=AnotherClass(99)) assert OmegaConf.is_config(obj.c)