diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 1b9f461b6..6246d6672 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -144,7 +144,7 @@ def _resolve_optional(type_: Any) -> Tuple[bool, Any]: args = type_.__args__ if len(args) == 2 and args[1] == type(None): # noqa E721 return True, args[0] - if type_ is Any: + if type_ is Any: # lgtm [py/comparison-using-is] return True, Any return False, type_ diff --git a/omegaconf/ir.py b/omegaconf/ir.py new file mode 100644 index 000000000..ce71ebf2f --- /dev/null +++ b/omegaconf/ir.py @@ -0,0 +1,90 @@ +import dataclasses # lgtm [py/import-and-import-from] +from dataclasses import dataclass +from typing import Any, Optional, get_type_hints + +from omegaconf._utils import ( + _resolve_forward, + _resolve_optional, + get_type_of, + is_attr_class, + is_dataclass, +) + + +@dataclass +class IR: + pass + + +@dataclass +class IRNode(IR): + name: Optional[str] + type: Any + val: Any + opt: bool + + +def get_dataclass_ir(obj: Any) -> IRNode: + from omegaconf.omegaconf import MISSING + + resolved_hints = get_type_hints(get_type_of(obj)) + assert is_dataclass(obj) + obj_type = get_type_of(obj) + children = [] + for fld in dataclasses.fields(obj): + name = fld.name + opt, type_ = _resolve_optional(resolved_hints[name]) + type_ = _resolve_forward(type_, fld.__module__) + + if hasattr(obj, name): + value = getattr(obj, name) + if value == dataclasses.MISSING: + value = MISSING + else: + if fld.default_factory == dataclasses.MISSING: # type: ignore + value = MISSING + else: + value = fld.default_factory() # type: ignore + ir = IRNode(name=name, type=type_, opt=opt, val=value) + children.append(ir) + + return IRNode(name=None, val=children, type=obj_type, opt=False) + + +def get_attr_ir(obj: Any) -> IRNode: + import attr + import attr._make + + from omegaconf.omegaconf import MISSING + + resolved_hints = get_type_hints(get_type_of(obj)) + assert is_attr_class(obj) + obj_type = get_type_of(obj) + children = [] + for name, attrib in attr.fields_dict(obj).items(): + # for fld in dataclasses.fields(obj): + # name = fld.name + opt, type_ = _resolve_optional(resolved_hints[name]) + type_ = _resolve_forward(type_, obj_type.__module__) + + assert not hasattr(obj, name) # no test coverage for this case yet + if attrib.default == attr.NOTHING: + value = MISSING + elif isinstance(attrib.default, attr._make.Factory): + assert not attrib.default.takes_self, "not supported yet" + value = attrib.default.factory() + else: + value = attrib.default + ir = IRNode(name=name, type=type_, opt=opt, val=value) + children.append(ir) + + return IRNode(name=None, val=children, type=obj_type, opt=False) + + +def get_structured_config_ir(obj: Any) -> IRNode: + if is_dataclass(obj): + return get_dataclass_ir(obj) + elif is_attr_class(obj): + return get_attr_ir(obj) + else: + raise ValueError(f"Unsupported type: {type(obj).__name__}") diff --git a/tests/ir/__init__.py b/tests/ir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ir/data/__init__.py b/tests/ir/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ir/data/attr.py b/tests/ir/data/attr.py new file mode 100644 index 000000000..3008d032d --- /dev/null +++ b/tests/ir/data/attr.py @@ -0,0 +1,34 @@ +import attr +from attr import NOTHING as backend_MISSING + +from omegaconf import MISSING + + +@attr.s(auto_attribs=True) +class User: + name: str + age: int + + +@attr.s(auto_attribs=True) +class UserWithMissing: + name: str = MISSING + age: int = MISSING + + +@attr.s(auto_attribs=True) +class UserWithBackendMissing: + name: str = backend_MISSING # type: ignore + age: int = backend_MISSING # type: ignore + + +@attr.s(auto_attribs=True) +class UserWithDefault: + name: str = "bond" + age: int = 7 + + +@attr.s(auto_attribs=True) +class UserWithDefaultFactory: + name: str = attr.ib(factory=lambda: "bond") + age: int = attr.ib(factory=lambda: 7) diff --git a/tests/ir/data/dataclass.py b/tests/ir/data/dataclass.py new file mode 100644 index 000000000..f3c210d88 --- /dev/null +++ b/tests/ir/data/dataclass.py @@ -0,0 +1,34 @@ +from dataclasses import MISSING as backend_MISSING +from dataclasses import dataclass, field + +from omegaconf import MISSING + + +@dataclass +class User: + name: str + age: int + + +@dataclass +class UserWithMissing: + name: str = MISSING + age: int = MISSING + + +@dataclass +class UserWithBackendMissing: + name: str = backend_MISSING # type: ignore + age: int = backend_MISSING # type: ignore + + +@dataclass +class UserWithDefault: + name: str = "bond" + age: int = 7 + + +@dataclass +class UserWithDefaultFactory: + name: str = field(default_factory=lambda: "bond") + age: int = field(default_factory=lambda: 7) diff --git a/tests/ir/test_ir.py b/tests/ir/test_ir.py new file mode 100644 index 000000000..4aaa9c650 --- /dev/null +++ b/tests/ir/test_ir.py @@ -0,0 +1,107 @@ +import copy +from importlib import import_module +from typing import Any + +from pytest import fixture, raises + +from omegaconf import MISSING +from omegaconf.ir import IRNode, get_structured_config_ir + + +def resolve_types(module: Any, ir: IRNode) -> None: + if isinstance(ir.type, str): + ir.type = getattr(module, ir.type) + + if isinstance(ir.val, list): + for c in ir.val: + resolve_types(module, c) + + +@fixture( + params=["tests.ir.data.dataclass", "tests.ir.data.attr"], + ids=lambda x: x.split(".")[-1], # type: ignore +) +def module(request: Any) -> Any: + return import_module(request.param) + + +@fixture( + params=[ + ( + "User", + IRNode( + name=None, + type="User", + opt=False, + val=[ + IRNode(name="name", type=str, opt=False, val=MISSING), + IRNode(name="age", type=int, opt=False, val=MISSING), + ], + ), + ), + ( + "UserWithMissing", + IRNode( + name=None, + type="UserWithMissing", + opt=False, + val=[ + IRNode(name="name", type=str, opt=False, val=MISSING), + IRNode(name="age", type=int, opt=False, val=MISSING), + ], + ), + ), + ( + "UserWithBackendMissing", + IRNode( + name=None, + type="UserWithBackendMissing", + opt=False, + val=[ + IRNode(name="name", type=str, opt=False, val=MISSING), + IRNode(name="age", type=int, opt=False, val=MISSING), + ], + ), + ), + ( + "UserWithDefault", + IRNode( + name=None, + type="UserWithDefault", + opt=False, + val=[ + IRNode(name="name", type=str, opt=False, val="bond"), + IRNode(name="age", type=int, opt=False, val=7), + ], + ), + ), + ( + "UserWithDefaultFactory", + IRNode( + name=None, + type="UserWithDefaultFactory", + opt=False, + val=[ + IRNode(name="name", type=str, opt=False, val="bond"), + IRNode(name="age", type=int, opt=False, val=7), + ], + ), + ), + ], + ids=lambda x: x[0], # type: ignore +) +def tested_type(module: Any, request: Any) -> Any: + name = request.param[0] + expected = copy.deepcopy(request.param[1]) + resolve_types(module, expected) + return {"type": getattr(module, name), "expected": expected} + + +def test_get_dataclass_ir(tested_type: Any) -> None: + assert get_structured_config_ir(tested_type["type"]) == tested_type["expected"] + + +def test_get_structured_config_ir_rejects_nonstructured() -> None: + """`get_structured_config_ir` should reject input that is not structured""" + with raises(ValueError): + get_structured_config_ir(object()) diff --git a/tests/test_utils.py b/tests/test_utils.py index 207e54230..f53f49e6b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -212,7 +212,7 @@ class Foo: assert _utils.is_dataclass(Foo) assert _utils.is_dataclass(Foo()) assert not _utils.is_dataclass(10) - + # TODO: dataclasses are now mandatory, clean this up. mocker.patch("omegaconf._utils.dataclasses", None) assert not _utils.is_dataclass(10)