Skip to content

Commit

Permalink
fix: add PEP 563 support, see #43
Browse files Browse the repository at this point in the history
  • Loading branch information
zifeo committed Jun 1, 2022
1 parent d362282 commit 304e855
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 26 deletions.
77 changes: 59 additions & 18 deletions dataconf/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import os
from typing import List

Expand All @@ -9,11 +10,41 @@
import pyparsing


def inject_callee_scope(func):
def inner(*args, **kwargs):
noglobals = "globalns" not in kwargs
nolocals = "localns" not in kwargs

if noglobals or nolocals:
frame = inspect.stack()[1][0]

if noglobals:
kwargs["globalns"] = frame.f_globals
if nolocals:
kwargs["localns"] = frame.f_locals

return func(
*args,
**kwargs,
)

return inner


@inject_callee_scope
def parse(
conf: ConfigTree, clazz, strict: bool = True, ignore_unexpected: bool = False
conf: ConfigTree,
clazz,
strict: bool = True,
ignore_unexpected: bool = False,
globalns=None,
localns=None,
):

try:
return utils.__parse(conf, clazz, "", strict, ignore_unexpected)
return utils.__parse(
conf, clazz, "", strict, ignore_unexpected, globalns, localns
)
except pyparsing.ParseSyntaxException as e:
raise MalformedConfigException(
f'parsing failure line {e.lineno} character {e.col}, got "{e.line}"'
Expand Down Expand Up @@ -50,42 +81,52 @@ def file(self, path: str, **kwargs) -> "Multi":
conf = ConfigFactory.parse_file(path)
return Multi(self.confs + [conf], self.strict, **kwargs)

def on(self, clazz):
@inject_callee_scope
def on(self, clazz, globalns=None, localns=None):
conf, *nxts = self.confs
for nxt in nxts:
conf = ConfigTree.merge_configs(conf, nxt)
return parse(conf, clazz, self.strict, **self.kwargs)
return parse(
conf, clazz, self.strict, globalns=globalns, localns=localns, **self.kwargs
)


multi = Multi([])


def env(prefix: str, clazz, **kwargs):
return multi.env(prefix, **kwargs).on(clazz)
@inject_callee_scope
def env(prefix: str, clazz, globalns=None, localns=None, **kwargs):
return multi.env(prefix, **kwargs).on(clazz, globalns=globalns, localns=localns)


def dict(obj: str, clazz, **kwargs):
return multi.dict(obj, **kwargs).on(clazz)
@inject_callee_scope
def dict(obj: str, clazz, globalns=None, localns=None, **kwargs):
return multi.dict(obj, **kwargs).on(clazz, globalns=globalns, localns=localns)


def string(s: str, clazz, **kwargs):
return multi.string(s, **kwargs).on(clazz)
@inject_callee_scope
def string(s: str, clazz, globalns=None, localns=None, **kwargs):
return multi.string(s, **kwargs).on(clazz, globalns=globalns, localns=localns)


def url(uri: str, clazz, **kwargs):
return multi.url(uri, **kwargs).on(clazz)
@inject_callee_scope
def url(uri: str, clazz, globalns=None, localns=None, **kwargs):
return multi.url(uri, **kwargs).on(clazz, globalns=globalns, localns=localns)


def file(path: str, clazz, **kwargs):
return multi.file(path, **kwargs).on(clazz)
@inject_callee_scope
def file(path: str, clazz, globalns=None, localns=None, **kwargs):
return multi.file(path, **kwargs).on(clazz, globalns=globalns, localns=localns)


def load(path: str, clazz, **kwargs):
return file(path, clazz, **kwargs)
@inject_callee_scope
def load(path: str, clazz, globalns=None, localns=None, **kwargs):
return file(path, clazz, globalns=globalns, localns=localns, **kwargs)


def loads(s: str, clazz, **kwargs):
return string(s, clazz, **kwargs)
@inject_callee_scope
def loads(s: str, clazz, globalns=None, localns=None, **kwargs):
return string(s, clazz, globalns=globalns, localns=localns, **kwargs)


def dump(file: str, instance: object, out: str):
Expand Down
54 changes: 46 additions & 8 deletions dataconf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime
from typing import get_args
from typing import get_origin
from typing import get_type_hints
from typing import Union

from dataconf.exceptions import EnvListOrderException
Expand Down Expand Up @@ -39,7 +40,7 @@ def is_optional(type):
return get_origin(type) is Union and NoneType in get_args(type)


def __parse(value: any, clazz, path, strict, ignore_unexpected):
def __parse(value: any, clazz, path, strict, ignore_unexpected, globalns, localns):

if is_dataclass(clazz):

Expand All @@ -51,6 +52,7 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected):
fs = {}
renamings = dict()

type_hints = get_type_hints(clazz, globalns, localns)
for f in fields(clazz):

if f.name in value:
Expand All @@ -66,10 +68,16 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected):

if not isinstance(val, _MISSING_TYPE):
fs[f.name] = __parse(
val, f.type, f"{path}.{f.name}", strict, ignore_unexpected
val,
type_hints[f.name],
f"{path}.{f.name}",
strict,
ignore_unexpected,
globalns,
localns,
)

elif is_optional(f.type):
elif is_optional(type_hints[f.name]):
# Optional not found
fs[f.name] = None

Expand All @@ -94,7 +102,15 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected):
raise MissingTypeException("expected list with type information: List[?]")
if value is not None:
return [
__parse(v, args[0], f"{path}[]", strict, ignore_unexpected)
__parse(
v,
args[0],
f"{path}[]",
strict,
ignore_unexpected,
globalns,
localns,
)
for v in value
]
return None
Expand All @@ -106,7 +122,15 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected):
)
if value is not None:
return {
k: __parse(v, args[1], f"{path}.{k}", strict, ignore_unexpected)
k: __parse(
v,
args[1],
f"{path}.{k}",
strict,
ignore_unexpected,
globalns,
localns,
)
for k, v in value.items()
}
return None
Expand All @@ -120,6 +144,8 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected):
path,
strict,
ignore_unexpected,
globalns,
localns,
)
except TypeConfigException:
# cannot parse Optional
Expand All @@ -129,10 +155,14 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected):
left, right = args

try:
return __parse(value, left, path, strict, ignore_unexpected)
return __parse(
value, left, path, strict, ignore_unexpected, globalns, localns
)
except TypeConfigException as left_failure:
try:
return __parse(value, right, path, strict, ignore_unexpected)
return __parse(
value, right, path, strict, ignore_unexpected, globalns, localns
)
except TypeConfigException as right_failure:
raise TypeConfigException(
f"expected type {clazz} at {path}, failed both:\n- {left_failure}\n- {right_failure}"
Expand Down Expand Up @@ -186,7 +216,15 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected):
for child_clazz in sorted(clazz.__subclasses__(), key=lambda c: c.__name__):
if is_dataclass(child_clazz):
try:
return __parse(value, child_clazz, path, strict, ignore_unexpected)
return __parse(
value,
child_clazz,
path,
strict,
ignore_unexpected,
globalns,
localns,
)
except (
TypeConfigException,
MalformedConfigException,
Expand Down
35 changes: 35 additions & 0 deletions tests/test_futur_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from dataclasses import dataclass
import os
from typing import get_type_hints
from typing import Text

import dataconf
from dataconf.main import inject_callee_scope


@inject_callee_scope
def out_of_scope_assert(clazz, expected, globalns, localns):
assert get_type_hints(clazz, globalns, localns)["a"] is expected


class TestFuturAnnotations:
def test_43(self) -> None:
@dataclass
class Model:
token: str

os.environ["TEST_token"] = "1"
dataconf.env("TEST_", Model)

def test_repro(self) -> None:
@dataclass
class A:
value: Text

@dataclass
class B:
a: A

out_of_scope_assert(B, A, globalns={})

0 comments on commit 304e855

Please sign in to comment.