From 63111518f3a80eaa41cc04291b3c24fcfaa5542e Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Wed, 3 Feb 2021 23:31:18 -0800 Subject: [PATCH 1/5] wip --- omegaconf/ir.py | 50 +++++++++++++++++++++++++++++ tests/ir/__init__.py | 0 tests/ir/data/__init__.py | 0 tests/ir/data/attr.py | 15 +++++++++ tests/ir/data/dataclass.py | 15 +++++++++ tests/ir/test_ir.py | 65 ++++++++++++++++++++++++++++++++++++++ tests/test_utils.py | 2 +- 7 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 omegaconf/ir.py create mode 100644 tests/ir/__init__.py create mode 100644 tests/ir/data/__init__.py create mode 100644 tests/ir/data/attr.py create mode 100644 tests/ir/data/dataclass.py create mode 100644 tests/ir/test_ir.py diff --git a/omegaconf/ir.py b/omegaconf/ir.py new file mode 100644 index 000000000..279c7b324 --- /dev/null +++ b/omegaconf/ir.py @@ -0,0 +1,50 @@ +import dataclasses +from dataclasses import dataclass +from typing import Any, Optional, get_type_hints + +from omegaconf._utils import ( + _resolve_forward, + _resolve_optional, + get_type_of, + is_dataclass, +) + + +@dataclass +class IR: + pass + + +@dataclass +class IRNode(IR): + name: Optional[str] + type: Any + val: Any + opt: bool + + +def get_dataclass_ir(obj: Any) -> IRNode: + from omegaconf.omegaconf import MISSING + + resolved_hints = get_type_hints(get_type_of(obj)) + assert is_dataclass(obj) + obj_type = get_type_of(obj) + children = [] + for fld in dataclasses.fields(obj): + name = fld.name + opt, type_ = _resolve_optional(resolved_hints[name]) + type_ = _resolve_forward(type_, fld.__module__) + + if hasattr(obj, name): + value = getattr(obj, name) + if value == dataclasses.MISSING: + value = MISSING + else: + if fld.default_factory == dataclasses.MISSING: # type: ignore + value = MISSING + else: + value = fld.default_factory() # type: ignore + ir = IRNode(name=name, type=type_, opt=opt, val=value) + children.append(ir) + + return IRNode(name=None, val=children, type=obj_type, opt=False) diff --git a/tests/ir/__init__.py b/tests/ir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ir/data/__init__.py b/tests/ir/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/ir/data/attr.py b/tests/ir/data/attr.py new file mode 100644 index 000000000..f45e0c49c --- /dev/null +++ b/tests/ir/data/attr.py @@ -0,0 +1,15 @@ +import attr + +from omegaconf import MISSING + + +@attr.s(auto_attribs=True) +class User: + name: str + age: int + + +@attr.s(auto_attribs=True) +class UserWithMissing: + name: str = MISSING + age: int = MISSING diff --git a/tests/ir/data/dataclass.py b/tests/ir/data/dataclass.py new file mode 100644 index 000000000..35eb874cd --- /dev/null +++ b/tests/ir/data/dataclass.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass + +from omegaconf import MISSING + + +@dataclass +class User: + name: str + age: int + + +@dataclass +class UserWithMissing: + name: str = MISSING + age: int = MISSING diff --git a/tests/ir/test_ir.py b/tests/ir/test_ir.py new file mode 100644 index 000000000..7c7068962 --- /dev/null +++ b/tests/ir/test_ir.py @@ -0,0 +1,65 @@ +import copy +from importlib import import_module +from typing import Any + +from pytest import fixture + +from omegaconf import MISSING +from omegaconf.ir import IRNode, get_dataclass_ir + + +def resolve_types(module: Any, ir: IRNode) -> None: + if isinstance(ir.type, str): + ir.type = getattr(module, ir.type) + + if isinstance(ir.val, list): + for c in ir.val: + resolve_types(module, c) + + +@fixture( + params=["tests.ir.data.dataclass", "tests.ir.data.attr"], + ids=lambda x: x.split(".")[-1], +) +def module(request: Any) -> Any: + return import_module(request.param) + + +@fixture( + params=[ + ( + "User", + IRNode( + name=None, + type="User", + opt=False, + val=[ + IRNode(name="name", type=str, opt=False, val=MISSING), + IRNode(name="age", type=int, opt=False, val=MISSING), + ], + ), + ), + ( + "UserWithMissing", + IRNode( + name=None, + type="UserWithMissing", + opt=False, + val=[ + IRNode(name="name", type=str, opt=False, val=MISSING), + IRNode(name="age", type=int, opt=False, val=MISSING), + ], + ), + ), + ], + ids=lambda x: x[0], +) +def tested_type(module: Any, request: Any) -> Any: + name = request.param[0] + expected = copy.deepcopy(request.param[1]) + resolve_types(module, expected) + return {"type": getattr(module, name), "expected": expected} + + +def test_get_dataclass_ir(tested_type: Any): + assert get_dataclass_ir(tested_type["type"]) == tested_type["expected"] diff --git a/tests/test_utils.py b/tests/test_utils.py index 207e54230..f53f49e6b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -212,7 +212,7 @@ class Foo: assert _utils.is_dataclass(Foo) assert _utils.is_dataclass(Foo()) assert not _utils.is_dataclass(10) - + # TODO: dataclasses are now mandatory, clean this up. mocker.patch("omegaconf._utils.dataclasses", None) assert not _utils.is_dataclass(10) From 1d5550967387ffa8f510753699b981928b457d0f Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 4 Feb 2021 14:33:12 -0600 Subject: [PATCH 2/5] fix mypy errors --- tests/ir/test_ir.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/ir/test_ir.py b/tests/ir/test_ir.py index 7c7068962..958ee10f2 100644 --- a/tests/ir/test_ir.py +++ b/tests/ir/test_ir.py @@ -19,7 +19,7 @@ def resolve_types(module: Any, ir: IRNode) -> None: @fixture( params=["tests.ir.data.dataclass", "tests.ir.data.attr"], - ids=lambda x: x.split(".")[-1], + ids=lambda x: x.split(".")[-1], # type: ignore ) def module(request: Any) -> Any: return import_module(request.param) @@ -52,7 +52,7 @@ def module(request: Any) -> Any: ), ), ], - ids=lambda x: x[0], + ids=lambda x: x[0], # type: ignore ) def tested_type(module: Any, request: Any) -> Any: name = request.param[0] @@ -61,5 +61,5 @@ def tested_type(module: Any, request: Any) -> Any: return {"type": getattr(module, name), "expected": expected} -def test_get_dataclass_ir(tested_type: Any): +def test_get_dataclass_ir(tested_type: Any) -> None: assert get_dataclass_ir(tested_type["type"]) == tested_type["expected"] From 3c97fbd29023e8e9af8290b81672bf1c0095aead Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sat, 6 Feb 2021 11:40:31 -0600 Subject: [PATCH 3/5] support for attr classes --- omegaconf/ir.py | 37 +++++++++++++++++++++++++++++++++++++ tests/ir/test_ir.py | 4 ++-- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/omegaconf/ir.py b/omegaconf/ir.py index 279c7b324..45683e359 100644 --- a/omegaconf/ir.py +++ b/omegaconf/ir.py @@ -2,10 +2,13 @@ from dataclasses import dataclass from typing import Any, Optional, get_type_hints +import attr + from omegaconf._utils import ( _resolve_forward, _resolve_optional, get_type_of, + is_attr_class, is_dataclass, ) @@ -48,3 +51,37 @@ def get_dataclass_ir(obj: Any) -> IRNode: children.append(ir) return IRNode(name=None, val=children, type=obj_type, opt=False) + + +def get_attr_ir(obj: Any) -> IRNode: + from omegaconf.omegaconf import MISSING + + resolved_hints = get_type_hints(get_type_of(obj)) + assert is_attr_class(obj) + obj_type = get_type_of(obj) + children = [] + for name in attr.fields_dict(obj).keys(): + # for fld in dataclasses.fields(obj): + # name = fld.name + opt, type_ = _resolve_optional(resolved_hints[name]) + type_ = _resolve_forward(type_, obj_type.__module__) + + if hasattr(obj, name): + value = getattr(obj, name) + if value == attr.NOTHING: + value = MISSING + else: + value = MISSING + ir = IRNode(name=name, type=type_, opt=opt, val=value) + children.append(ir) + + return IRNode(name=None, val=children, type=obj_type, opt=False) + + +def get_structured_config_ir(obj: Any) -> IRNode: + if is_dataclass(obj): + return get_dataclass_ir(obj) + elif is_attr_class(obj): + return get_attr_ir(obj) + else: + raise ValueError(f"Unsupported type: {type(obj).__name__}") diff --git a/tests/ir/test_ir.py b/tests/ir/test_ir.py index 958ee10f2..f663251a0 100644 --- a/tests/ir/test_ir.py +++ b/tests/ir/test_ir.py @@ -5,7 +5,7 @@ from pytest import fixture from omegaconf import MISSING -from omegaconf.ir import IRNode, get_dataclass_ir +from omegaconf.ir import IRNode, get_structured_config_ir def resolve_types(module: Any, ir: IRNode) -> None: @@ -62,4 +62,4 @@ def tested_type(module: Any, request: Any) -> Any: def test_get_dataclass_ir(tested_type: Any) -> None: - assert get_dataclass_ir(tested_type["type"]) == tested_type["expected"] + assert get_structured_config_ir(tested_type["type"]) == tested_type["expected"] From 71454d1eeab65673dc7163966dcb76ba65ff7e98 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sat, 6 Feb 2021 12:37:00 -0600 Subject: [PATCH 4/5] add more tests for code coverage --- omegaconf/ir.py | 19 +++++++++------- tests/ir/data/attr.py | 19 ++++++++++++++++ tests/ir/data/dataclass.py | 21 +++++++++++++++++- tests/ir/test_ir.py | 44 +++++++++++++++++++++++++++++++++++++- 4 files changed, 93 insertions(+), 10 deletions(-) diff --git a/omegaconf/ir.py b/omegaconf/ir.py index 45683e359..26b668681 100644 --- a/omegaconf/ir.py +++ b/omegaconf/ir.py @@ -2,8 +2,6 @@ from dataclasses import dataclass from typing import Any, Optional, get_type_hints -import attr - from omegaconf._utils import ( _resolve_forward, _resolve_optional, @@ -54,24 +52,29 @@ def get_dataclass_ir(obj: Any) -> IRNode: def get_attr_ir(obj: Any) -> IRNode: + import attr + import attr._make + from omegaconf.omegaconf import MISSING resolved_hints = get_type_hints(get_type_of(obj)) assert is_attr_class(obj) obj_type = get_type_of(obj) children = [] - for name in attr.fields_dict(obj).keys(): + for name, attrib in attr.fields_dict(obj).items(): # for fld in dataclasses.fields(obj): # name = fld.name opt, type_ = _resolve_optional(resolved_hints[name]) type_ = _resolve_forward(type_, obj_type.__module__) - if hasattr(obj, name): - value = getattr(obj, name) - if value == attr.NOTHING: - value = MISSING - else: + assert not hasattr(obj, name) # no test coverage for this case yet + if attrib.default == attr.NOTHING: value = MISSING + elif isinstance(attrib.default, attr._make.Factory): + assert not attrib.default.takes_self, "not supported yet" + value = attrib.default.factory() + else: + value = attrib.default ir = IRNode(name=name, type=type_, opt=opt, val=value) children.append(ir) diff --git a/tests/ir/data/attr.py b/tests/ir/data/attr.py index f45e0c49c..3008d032d 100644 --- a/tests/ir/data/attr.py +++ b/tests/ir/data/attr.py @@ -1,4 +1,5 @@ import attr +from attr import NOTHING as backend_MISSING from omegaconf import MISSING @@ -13,3 +14,21 @@ class User: class UserWithMissing: name: str = MISSING age: int = MISSING + + +@attr.s(auto_attribs=True) +class UserWithBackendMissing: + name: str = backend_MISSING # type: ignore + age: int = backend_MISSING # type: ignore + + +@attr.s(auto_attribs=True) +class UserWithDefault: + name: str = "bond" + age: int = 7 + + +@attr.s(auto_attribs=True) +class UserWithDefaultFactory: + name: str = attr.ib(factory=lambda: "bond") + age: int = attr.ib(factory=lambda: 7) diff --git a/tests/ir/data/dataclass.py b/tests/ir/data/dataclass.py index 35eb874cd..f3c210d88 100644 --- a/tests/ir/data/dataclass.py +++ b/tests/ir/data/dataclass.py @@ -1,4 +1,5 @@ -from dataclasses import dataclass +from dataclasses import MISSING as backend_MISSING +from dataclasses import dataclass, field from omegaconf import MISSING @@ -13,3 +14,21 @@ class User: class UserWithMissing: name: str = MISSING age: int = MISSING + + +@dataclass +class UserWithBackendMissing: + name: str = backend_MISSING # type: ignore + age: int = backend_MISSING # type: ignore + + +@dataclass +class UserWithDefault: + name: str = "bond" + age: int = 7 + + +@dataclass +class UserWithDefaultFactory: + name: str = field(default_factory=lambda: "bond") + age: int = field(default_factory=lambda: 7) diff --git a/tests/ir/test_ir.py b/tests/ir/test_ir.py index f663251a0..4aaa9c650 100644 --- a/tests/ir/test_ir.py +++ b/tests/ir/test_ir.py @@ -2,7 +2,7 @@ from importlib import import_module from typing import Any -from pytest import fixture +from pytest import fixture, raises from omegaconf import MISSING from omegaconf.ir import IRNode, get_structured_config_ir @@ -51,6 +51,42 @@ def module(request: Any) -> Any: ], ), ), + ( + "UserWithBackendMissing", + IRNode( + name=None, + type="UserWithBackendMissing", + opt=False, + val=[ + IRNode(name="name", type=str, opt=False, val=MISSING), + IRNode(name="age", type=int, opt=False, val=MISSING), + ], + ), + ), + ( + "UserWithDefault", + IRNode( + name=None, + type="UserWithDefault", + opt=False, + val=[ + IRNode(name="name", type=str, opt=False, val="bond"), + IRNode(name="age", type=int, opt=False, val=7), + ], + ), + ), + ( + "UserWithDefaultFactory", + IRNode( + name=None, + type="UserWithDefaultFactory", + opt=False, + val=[ + IRNode(name="name", type=str, opt=False, val="bond"), + IRNode(name="age", type=int, opt=False, val=7), + ], + ), + ), ], ids=lambda x: x[0], # type: ignore ) @@ -63,3 +99,9 @@ def tested_type(module: Any, request: Any) -> Any: def test_get_dataclass_ir(tested_type: Any) -> None: assert get_structured_config_ir(tested_type["type"]) == tested_type["expected"] + + +def test_get_structured_config_ir_rejects_nonstructured() -> None: + """`get_structured_config_ir` should reject input that is not structured""" + with raises(ValueError): + get_structured_config_ir(object()) From 50145b897d48916824c31da295cc5da1a9a8f92e Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Wed, 28 Apr 2021 15:35:49 -0500 Subject: [PATCH 5/5] comments to silence lgtm alerts --- omegaconf/_utils.py | 2 +- omegaconf/ir.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 1b9f461b6..6246d6672 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -144,7 +144,7 @@ def _resolve_optional(type_: Any) -> Tuple[bool, Any]: args = type_.__args__ if len(args) == 2 and args[1] == type(None): # noqa E721 return True, args[0] - if type_ is Any: + if type_ is Any: # lgtm [py/comparison-using-is] return True, Any return False, type_ diff --git a/omegaconf/ir.py b/omegaconf/ir.py index 26b668681..ce71ebf2f 100644 --- a/omegaconf/ir.py +++ b/omegaconf/ir.py @@ -1,4 +1,4 @@ -import dataclasses +import dataclasses # lgtm [py/import-and-import-from] from dataclasses import dataclass from typing import Any, Optional, get_type_hints