Skip to content

Commit

Permalink
Fix to_structured() for attrs classes with fields having leading "_". (
Browse files Browse the repository at this point in the history
  • Loading branch information
bzczb authored Jan 18, 2024
1 parent d888120 commit 5525cd5
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 12 deletions.
16 changes: 14 additions & 2 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,15 +482,27 @@ def is_structured_config_frozen(obj: Any) -> bool:
return False


def get_structured_config_init_field_names(obj: Any) -> List[str]:
def _find_attrs_init_field_alias(field: Any) -> str:
# New versions of attrs, after 22.2.0, have the alias explicitly defined.
# Previous versions implicitly strip the underscore in the init parameter.
if hasattr(field, "alias"):
assert isinstance(field.alias, str)
return field.alias
else: # pragma: no cover
assert isinstance(field.name, str)
return field.name.lstrip("_")


def get_structured_config_init_field_aliases(obj: Any) -> Dict[str, str]:
fields: Union[List["dataclasses.Field[Any]"], List["attr.Attribute[Any]"]]
if is_dataclass(obj):
fields = get_dataclass_fields(obj)
return {f.name: f.name for f in fields if f.init}
elif is_attr_class(obj):
fields = get_attr_class_fields(obj)
return {f.name: _find_attrs_init_field_alias(f) for f in fields if f.init}
else:
raise ValueError(f"Unsupported type: {type(obj).__name__}")
return [f.name for f in fields if f.init]


def get_structured_config_data(
Expand Down
10 changes: 5 additions & 5 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
_valid_dict_key_annotation_type,
format_and_raise,
get_structured_config_data,
get_structured_config_init_field_names,
get_structured_config_init_field_aliases,
get_type_of,
get_value_kind,
is_container_annotation,
Expand Down Expand Up @@ -726,7 +726,7 @@ def _to_object(self) -> Any:

object_type = self._metadata.object_type
assert is_structured_config(object_type)
init_field_names = set(get_structured_config_init_field_names(object_type))
init_field_aliases = get_structured_config_init_field_aliases(object_type)

init_field_items: Dict[str, Any] = {}
non_init_field_items: Dict[str, Any] = {}
Expand All @@ -739,7 +739,7 @@ def _to_object(self) -> Any:
except InterpolationResolutionError as e:
self._format_and_raise(key=k, value=None, cause=e)
if node._is_missing():
if k not in init_field_names:
if k not in init_field_aliases:
continue # MISSING is ignored for init=False fields
self._format_and_raise(
key=k,
Expand All @@ -753,8 +753,8 @@ def _to_object(self) -> Any:
else:
v = node._value()

if k in init_field_names:
init_field_items[k] = v
if k in init_field_aliases:
init_field_items[init_field_aliases[k]] = v
else:
non_init_field_items[k] = v

Expand Down
6 changes: 6 additions & 0 deletions tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,12 @@ def __attrs_post_init__(self) -> None:
self.post_initialized = "set_by_post_init"


@attr.s(auto_attribs=True)
class LeadingUnderscoreFields:
_foo: str = "x"
_bar: str = "y"


class NestedContainers:
@attr.s(auto_attribs=True)
class ListOfLists:
Expand Down
6 changes: 6 additions & 0 deletions tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,12 @@ def __post_init__(self) -> None:
self.post_initialized = "set_by_post_init"


@dataclass
class LeadingUnderscoreFields:
_foo: str = "x"
_bar: str = "y"


class NestedContainers:
@dataclass
class ListOfLists:
Expand Down
6 changes: 6 additions & 0 deletions tests/structured_conf/data/dataclasses_pre_311.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,12 @@ def __post_init__(self) -> None:
self.post_initialized = "set_by_post_init"


@dataclass
class LeadingUnderscoreFields:
_foo: str = "x"
_bar: str = "y"


class NestedContainers:
@dataclass
class ListOfLists:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_to_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,14 @@ def test_ignore_metadata_with_default_args(self, module: Any) -> None:
data = OmegaConf.to_object(cfg)
assert data == module.HasIgnoreMetadataWithDefault(1, 4)

def test_leading_underscore_fields(self, module: Any) -> None:
cfg = OmegaConf.structured(module.LeadingUnderscoreFields)
container = OmegaConf.to_container(
cfg, structured_config_mode=SCMode.INSTANTIATE
)
assert isinstance(container, module.LeadingUnderscoreFields)
assert container._foo == "x" and container._bar == "y"


class TestEnumToStr:
"""Test the `enum_to_str` argument to the `OmegaConf.to_container function`"""
Expand Down
12 changes: 7 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,15 @@ def test_get_structured_config_data_throws_ValueError(self) -> None:
"test_cls_or_obj",
[_TestDataclass, _TestDataclass(), _TestAttrsClass, _TestAttrsClass()],
)
def test_get_structured_config_field_names(self, test_cls_or_obj: Any) -> None:
field_names = _utils.get_structured_config_init_field_names(test_cls_or_obj)
assert field_names == ["x", "s", "b", "p", "d", "f", "e", "list1", "dict1"]
def test_get_structured_config_field_aliases(self, test_cls_or_obj: Any) -> None:
field_names = _utils.get_structured_config_init_field_aliases(test_cls_or_obj)
compare = ["x", "s", "b", "p", "d", "f", "e", "list1", "dict1"]
assert list(field_names.keys()) == compare
assert list(field_names.values()) == compare

def test_get_structured_config_field_names_throws_ValueError(self) -> None:
def test_get_structured_config_field_aliases_throws_ValueError(self) -> None:
with raises(ValueError):
_utils.get_structured_config_init_field_names("invalid")
_utils.get_structured_config_init_field_aliases("invalid")


@mark.parametrize(
Expand Down

0 comments on commit 5525cd5

Please sign in to comment.