Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Structured Configs IR : WIP #514

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
args = type_.__args__
if len(args) == 2 and args[1] == type(None): # noqa E721
return True, args[0]
if type_ is Any:
if type_ is Any: # lgtm [py/comparison-using-is]
return True, Any

return False, type_
Expand Down
90 changes: 90 additions & 0 deletions omegaconf/ir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import dataclasses # lgtm [py/import-and-import-from]
from dataclasses import dataclass
from typing import Any, Optional, get_type_hints

from omegaconf._utils import (
_resolve_forward,
_resolve_optional,
get_type_of,
is_attr_class,
is_dataclass,
)


@dataclass
class IR:
pass


@dataclass
class IRNode(IR):
name: Optional[str]
type: Any
val: Any
opt: bool


def get_dataclass_ir(obj: Any) -> IRNode:
from omegaconf.omegaconf import MISSING

resolved_hints = get_type_hints(get_type_of(obj))
assert is_dataclass(obj)
obj_type = get_type_of(obj)
children = []
for fld in dataclasses.fields(obj):
name = fld.name
opt, type_ = _resolve_optional(resolved_hints[name])
type_ = _resolve_forward(type_, fld.__module__)

if hasattr(obj, name):
value = getattr(obj, name)
if value == dataclasses.MISSING:
value = MISSING
else:
if fld.default_factory == dataclasses.MISSING: # type: ignore
value = MISSING
else:
value = fld.default_factory() # type: ignore
ir = IRNode(name=name, type=type_, opt=opt, val=value)
children.append(ir)

return IRNode(name=None, val=children, type=obj_type, opt=False)


def get_attr_ir(obj: Any) -> IRNode:
import attr
import attr._make

from omegaconf.omegaconf import MISSING

resolved_hints = get_type_hints(get_type_of(obj))
assert is_attr_class(obj)
obj_type = get_type_of(obj)
children = []
for name, attrib in attr.fields_dict(obj).items():
# for fld in dataclasses.fields(obj):
# name = fld.name
opt, type_ = _resolve_optional(resolved_hints[name])
type_ = _resolve_forward(type_, obj_type.__module__)

assert not hasattr(obj, name) # no test coverage for this case yet
if attrib.default == attr.NOTHING:
value = MISSING
elif isinstance(attrib.default, attr._make.Factory):
assert not attrib.default.takes_self, "not supported yet"
value = attrib.default.factory()
else:
value = attrib.default
ir = IRNode(name=name, type=type_, opt=opt, val=value)
children.append(ir)

return IRNode(name=None, val=children, type=obj_type, opt=False)


def get_structured_config_ir(obj: Any) -> IRNode:
if is_dataclass(obj):
return get_dataclass_ir(obj)
elif is_attr_class(obj):
return get_attr_ir(obj)
else:
raise ValueError(f"Unsupported type: {type(obj).__name__}")
Empty file added tests/ir/__init__.py
Empty file.
Empty file added tests/ir/data/__init__.py
Empty file.
34 changes: 34 additions & 0 deletions tests/ir/data/attr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import attr
from attr import NOTHING as backend_MISSING

from omegaconf import MISSING


@attr.s(auto_attribs=True)
class User:
name: str
age: int


@attr.s(auto_attribs=True)
class UserWithMissing:
name: str = MISSING
age: int = MISSING


@attr.s(auto_attribs=True)
class UserWithBackendMissing:
name: str = backend_MISSING # type: ignore
age: int = backend_MISSING # type: ignore


@attr.s(auto_attribs=True)
class UserWithDefault:
name: str = "bond"
age: int = 7


@attr.s(auto_attribs=True)
class UserWithDefaultFactory:
name: str = attr.ib(factory=lambda: "bond")
age: int = attr.ib(factory=lambda: 7)
34 changes: 34 additions & 0 deletions tests/ir/data/dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dataclasses import MISSING as backend_MISSING
from dataclasses import dataclass, field

from omegaconf import MISSING


@dataclass
class User:
name: str
age: int


@dataclass
class UserWithMissing:
name: str = MISSING
age: int = MISSING


@dataclass
class UserWithBackendMissing:
name: str = backend_MISSING # type: ignore
age: int = backend_MISSING # type: ignore


@dataclass
class UserWithDefault:
name: str = "bond"
age: int = 7


@dataclass
class UserWithDefaultFactory:
name: str = field(default_factory=lambda: "bond")
age: int = field(default_factory=lambda: 7)
107 changes: 107 additions & 0 deletions tests/ir/test_ir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import copy
from importlib import import_module
from typing import Any

from pytest import fixture, raises

from omegaconf import MISSING
from omegaconf.ir import IRNode, get_structured_config_ir


def resolve_types(module: Any, ir: IRNode) -> None:
if isinstance(ir.type, str):
ir.type = getattr(module, ir.type)

if isinstance(ir.val, list):
for c in ir.val:
resolve_types(module, c)


@fixture(
params=["tests.ir.data.dataclass", "tests.ir.data.attr"],
ids=lambda x: x.split(".")[-1], # type: ignore
)
def module(request: Any) -> Any:
return import_module(request.param)


@fixture(
params=[
(
"User",
IRNode(
name=None,
type="User",
opt=False,
val=[
IRNode(name="name", type=str, opt=False, val=MISSING),
IRNode(name="age", type=int, opt=False, val=MISSING),
],
),
),
(
"UserWithMissing",
IRNode(
name=None,
type="UserWithMissing",
opt=False,
val=[
IRNode(name="name", type=str, opt=False, val=MISSING),
IRNode(name="age", type=int, opt=False, val=MISSING),
],
),
),
(
"UserWithBackendMissing",
IRNode(
name=None,
type="UserWithBackendMissing",
opt=False,
val=[
IRNode(name="name", type=str, opt=False, val=MISSING),
IRNode(name="age", type=int, opt=False, val=MISSING),
],
),
),
(
"UserWithDefault",
IRNode(
name=None,
type="UserWithDefault",
opt=False,
val=[
IRNode(name="name", type=str, opt=False, val="bond"),
IRNode(name="age", type=int, opt=False, val=7),
],
),
),
(
"UserWithDefaultFactory",
IRNode(
name=None,
type="UserWithDefaultFactory",
opt=False,
val=[
IRNode(name="name", type=str, opt=False, val="bond"),
IRNode(name="age", type=int, opt=False, val=7),
],
),
),
],
ids=lambda x: x[0], # type: ignore
)
def tested_type(module: Any, request: Any) -> Any:
name = request.param[0]
expected = copy.deepcopy(request.param[1])
resolve_types(module, expected)
return {"type": getattr(module, name), "expected": expected}


def test_get_dataclass_ir(tested_type: Any) -> None:
assert get_structured_config_ir(tested_type["type"]) == tested_type["expected"]


def test_get_structured_config_ir_rejects_nonstructured() -> None:
"""`get_structured_config_ir` should reject input that is not structured"""
with raises(ValueError):
get_structured_config_ir(object())
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class Foo:
assert _utils.is_dataclass(Foo)
assert _utils.is_dataclass(Foo())
assert not _utils.is_dataclass(10)

# TODO: dataclasses are now mandatory, clean this up.
mocker.patch("omegaconf._utils.dataclasses", None)
assert not _utils.is_dataclass(10)

Expand Down