From 0612c1c51562cf30a47a493d2fafbd28f4a0aa7d Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Tue, 7 Nov 2023 00:56:16 -0600 Subject: [PATCH] Fix interpolation to structured config from within typed list This PR closes #1005, fixing a bug where a validation error was raised when an interpolation points to a structured config from within a typed list. ```python from dataclasses import dataclass, field from omegaconf import MISSING, OmegaConf @dataclass class User: name: str = MISSING @dataclass class Config: user: User = User("John") users: list[User] = field(default_factory=lambda: ["${user}"]) cfg = OmegaConf.structured(Config) OmegaConf.resolve(cfg) print(cfg) ``` BEFORE ------ ```text $ python repro.py $ python repro.py Traceback (most recent call last): ... omegaconf.errors.ValidationError: Invalid type assigned: str is not a subclass of User. value: ${user} full_key: users[0] reference_type=List[User] object_type=list ``` AFTER ----- ```text $ python repro.py {'user': {'name': 'John'}, 'users': [{'name': 'John'}]} ``` --- omegaconf/listconfig.py | 2 ++ tests/structured_conf/data/attr_classes.py | 2 ++ tests/structured_conf/data/dataclasses.py | 2 ++ tests/structured_conf/data/dataclasses_pre_311.py | 2 ++ tests/structured_conf/test_structured_basic.py | 4 ++++ 5 files changed, 12 insertions(+) diff --git a/omegaconf/listconfig.py b/omegaconf/listconfig.py index f9f430c01..20a8f8169 100644 --- a/omegaconf/listconfig.py +++ b/omegaconf/listconfig.py @@ -106,6 +106,8 @@ def _validate_set(self, key: Any, value: Any) -> None: vk = get_value_kind(value) if vk == ValueKind.MANDATORY_MISSING: return + if vk == ValueKind.INTERPOLATION: + return else: is_optional, target_type = _resolve_optional(self._metadata.element_type) value_type = OmegaConf.get_type(value) diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index f0b32d02c..ba4c15e86 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -76,6 +76,8 @@ class OptionalUser: class InterpolationToUser: user: User = User("Bond", 7) admin: User = II("user") + admin_list: List[User] = [II("user")] + admin_dict: Dict[str, User] = {"bond": II("user")} @attr.s(auto_attribs=True) diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index b79640794..88607cb12 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -77,6 +77,8 @@ class OptionalUser: class InterpolationToUser: user: User = field(default_factory=lambda: User("Bond", 7)) admin: User = II("user") + admin_list: List[User] = field(default_factory=lambda: [II("user")]) + admin_dict: Dict[str, User] = field(default_factory=lambda: {"bond": II("user")}) @dataclass diff --git a/tests/structured_conf/data/dataclasses_pre_311.py b/tests/structured_conf/data/dataclasses_pre_311.py index 1401d4324..c045e20d7 100644 --- a/tests/structured_conf/data/dataclasses_pre_311.py +++ b/tests/structured_conf/data/dataclasses_pre_311.py @@ -77,6 +77,8 @@ class OptionalUser: class InterpolationToUser: user: User = User("Bond", 7) admin: User = II("user") + admin_list: List[User] = field(default_factory=lambda: [II("user")]) + admin_dict: Dict[str, User] = field(default_factory=lambda: {"bond": II("user")}) @dataclass diff --git a/tests/structured_conf/test_structured_basic.py b/tests/structured_conf/test_structured_basic.py index 4fe4f168e..768908e25 100644 --- a/tests/structured_conf/test_structured_basic.py +++ b/tests/structured_conf/test_structured_basic.py @@ -247,6 +247,10 @@ def test_interpolation_to_structured(self, module: Any, resolve: bool) -> None: OmegaConf.resolve(cfg) assert OmegaConf.get_type(cfg.admin) is module.User assert cfg.admin == {"name": "Bond", "age": 7} + assert OmegaConf.get_type(cfg.admin_list[0]) is module.User + assert cfg.admin_list == [{"name": "Bond", "age": 7}] + assert OmegaConf.get_type(cfg.admin_dict["bond"]) is module.User + assert cfg.admin_dict == {"bond": {"name": "Bond", "age": 7}} class TestMissing: def test_missing1(self, module: Any) -> None: