From 5525cd5b0659311e9a2905bca3cac2fcdb5189dc Mon Sep 17 00:00:00 2001
From: vermont <114126196+bzczb@users.noreply.github.com>
Date: Thu, 18 Jan 2024 07:41:13 -0500
Subject: [PATCH] Fix to_structured() for attrs classes with fields having leading "_". (#1135)

---
Review notes (this section is not part of the commit message):

* get_structured_config_init_field_names() (returned List[str]) is replaced
  by get_structured_config_init_field_aliases() (returns Dict[str, str]),
  mapping each init=True field name to the keyword passed to the class
  constructor. DictConfig._to_object() now stores values under that keyword.
* Dataclass fields map to their own name. attrs fields use `field.alias`
  when the attribute has one, otherwise the name with leading underscores
  stripped.
* NOTE(review): attrs documents `Attribute.alias` as new in 22.2.0, so the
  "after 22.2.0" wording in the added comment presumably means "22.2.0 and
  later"; confirm.
* NOTE(review): the subject says to_structured(), but the code touched is
  _to_object() and the new test goes through OmegaConf.to_container(...,
  structured_config_mode=SCMode.INSTANTIATE); confirm the intended name.

 omegaconf/_utils.py                              | 16 ++++++++++++++--
 omegaconf/dictconfig.py                          | 10 +++++-----
 tests/structured_conf/data/attr_classes.py       |  6 ++++++
 tests/structured_conf/data/dataclasses.py        |  6 ++++++
 .../structured_conf/data/dataclasses_pre_311.py  |  6 ++++++
 tests/test_to_container.py                       |  8 ++++++++
 tests/test_utils.py                              | 12 +++++++-----
 7 files changed, 52 insertions(+), 12 deletions(-)

diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py
index 2575f0452..c630d10ba 100644
--- a/omegaconf/_utils.py
+++ b/omegaconf/_utils.py
@@ -482,15 +482,27 @@ def is_structured_config_frozen(obj: Any) -> bool:
     return False
 
 
-def get_structured_config_init_field_names(obj: Any) -> List[str]:
+def _find_attrs_init_field_alias(field: Any) -> str:
+    # New versions of attrs, after 22.2.0, have the alias explicitly defined.
+    # Previous versions implicitly strip the underscore in the init parameter.
+    if hasattr(field, "alias"):
+        assert isinstance(field.alias, str)
+        return field.alias
+    else:  # pragma: no cover
+        assert isinstance(field.name, str)
+        return field.name.lstrip("_")
+
+
+def get_structured_config_init_field_aliases(obj: Any) -> Dict[str, str]:
     fields: Union[List["dataclasses.Field[Any]"], List["attr.Attribute[Any]"]]
     if is_dataclass(obj):
         fields = get_dataclass_fields(obj)
+        return {f.name: f.name for f in fields if f.init}
     elif is_attr_class(obj):
         fields = get_attr_class_fields(obj)
+        return {f.name: _find_attrs_init_field_alias(f) for f in fields if f.init}
     else:
         raise ValueError(f"Unsupported type: {type(obj).__name__}")
-    return [f.name for f in fields if f.init]
 
 
 def get_structured_config_data(
diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py
index 12c1ebde9..11b2e701d 100644
--- a/omegaconf/dictconfig.py
+++ b/omegaconf/dictconfig.py
@@ -28,7 +28,7 @@
     _valid_dict_key_annotation_type,
     format_and_raise,
     get_structured_config_data,
-    get_structured_config_init_field_names,
+    get_structured_config_init_field_aliases,
     get_type_of,
     get_value_kind,
     is_container_annotation,
@@ -726,7 +726,7 @@ def _to_object(self) -> Any:
 
         object_type = self._metadata.object_type
         assert is_structured_config(object_type)
-        init_field_names = set(get_structured_config_init_field_names(object_type))
+        init_field_aliases = get_structured_config_init_field_aliases(object_type)
 
         init_field_items: Dict[str, Any] = {}
         non_init_field_items: Dict[str, Any] = {}
@@ -739,7 +739,7 @@ def _to_object(self) -> Any:
             except InterpolationResolutionError as e:
                 self._format_and_raise(key=k, value=None, cause=e)
             if node._is_missing():
-                if k not in init_field_names:
+                if k not in init_field_aliases:
                     continue  # MISSING is ignored for init=False fields
                 self._format_and_raise(
                     key=k,
@@ -753,8 +753,8 @@ def _to_object(self) -> Any:
             else:
                 v = node._value()
 
-            if k in init_field_names:
-                init_field_items[k] = v
+            if k in init_field_aliases:
+                init_field_items[init_field_aliases[k]] = v
             else:
                 non_init_field_items[k] = v
 
diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py
index ba4c15e86..93a9e2c30 100644
--- a/tests/structured_conf/data/attr_classes.py
+++ b/tests/structured_conf/data/attr_classes.py
@@ -653,6 +653,12 @@ def __attrs_post_init__(self) -> None:
         self.post_initialized = "set_by_post_init"
 
 
+@attr.s(auto_attribs=True)
+class LeadingUnderscoreFields:
+    _foo: str = "x"
+    _bar: str = "y"
+
+
 class NestedContainers:
     @attr.s(auto_attribs=True)
     class ListOfLists:
diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py
index 88607cb12..27df38627 100644
--- a/tests/structured_conf/data/dataclasses.py
+++ b/tests/structured_conf/data/dataclasses.py
@@ -678,6 +678,12 @@ def __post_init__(self) -> None:
         self.post_initialized = "set_by_post_init"
 
 
+@dataclass
+class LeadingUnderscoreFields:
+    _foo: str = "x"
+    _bar: str = "y"
+
+
 class NestedContainers:
     @dataclass
     class ListOfLists:
diff --git a/tests/structured_conf/data/dataclasses_pre_311.py b/tests/structured_conf/data/dataclasses_pre_311.py
index c045e20d7..e40e56fc0 100644
--- a/tests/structured_conf/data/dataclasses_pre_311.py
+++ b/tests/structured_conf/data/dataclasses_pre_311.py
@@ -674,6 +674,12 @@ def __post_init__(self) -> None:
         self.post_initialized = "set_by_post_init"
 
 
+@dataclass
+class LeadingUnderscoreFields:
+    _foo: str = "x"
+    _bar: str = "y"
+
+
 class NestedContainers:
     @dataclass
     class ListOfLists:
diff --git a/tests/test_to_container.py b/tests/test_to_container.py
index f3c5f910a..04e64df8c 100644
--- a/tests/test_to_container.py
+++ b/tests/test_to_container.py
@@ -527,6 +527,14 @@ def test_ignore_metadata_with_default_args(self, module: Any) -> None:
         data = OmegaConf.to_object(cfg)
         assert data == module.HasIgnoreMetadataWithDefault(1, 4)
 
+    def test_leading_underscore_fields(self, module: Any) -> None:
+        cfg = OmegaConf.structured(module.LeadingUnderscoreFields)
+        container = OmegaConf.to_container(
+            cfg, structured_config_mode=SCMode.INSTANTIATE
+        )
+        assert isinstance(container, module.LeadingUnderscoreFields)
+        assert container._foo == "x" and container._bar == "y"
+
 
 class TestEnumToStr:
     """Test the `enum_to_str` argument to the `OmegaConf.to_container function`"""
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 442969d3d..686f712d2 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -371,13 +371,15 @@ def test_get_structured_config_data_throws_ValueError(self) -> None:
         "test_cls_or_obj",
         [_TestDataclass, _TestDataclass(), _TestAttrsClass, _TestAttrsClass()],
     )
-    def test_get_structured_config_field_names(self, test_cls_or_obj: Any) -> None:
-        field_names = _utils.get_structured_config_init_field_names(test_cls_or_obj)
-        assert field_names == ["x", "s", "b", "p", "d", "f", "e", "list1", "dict1"]
+    def test_get_structured_config_field_aliases(self, test_cls_or_obj: Any) -> None:
+        field_names = _utils.get_structured_config_init_field_aliases(test_cls_or_obj)
+        compare = ["x", "s", "b", "p", "d", "f", "e", "list1", "dict1"]
+        assert list(field_names.keys()) == compare
+        assert list(field_names.values()) == compare
 
-    def test_get_structured_config_field_names_throws_ValueError(self) -> None:
+    def test_get_structured_config_field_aliases_throws_ValueError(self) -> None:
         with raises(ValueError):
-            _utils.get_structured_config_init_field_names("invalid")
+            _utils.get_structured_config_init_field_aliases("invalid")
 
 
 @mark.parametrize(