diff --git a/dataconf/main.py b/dataconf/main.py index f05e514..28faab2 100644 --- a/dataconf/main.py +++ b/dataconf/main.py @@ -1,3 +1,4 @@ +import inspect import os from typing import List @@ -9,11 +10,41 @@ import pyparsing +def inject_callee_scope(func): + def inner(*args, **kwargs): + noglobals = "globalns" not in kwargs + nolocals = "localns" not in kwargs + + if noglobals or nolocals: + frame = inspect.stack()[1][0] + + if noglobals: + kwargs["globalns"] = frame.f_globals + if nolocals: + kwargs["localns"] = frame.f_locals + + return func( + *args, + **kwargs, + ) + + return inner + + +@inject_callee_scope def parse( - conf: ConfigTree, clazz, strict: bool = True, ignore_unexpected: bool = False + conf: ConfigTree, + clazz, + strict: bool = True, + ignore_unexpected: bool = False, + globalns=None, + localns=None, ): + try: - return utils.__parse(conf, clazz, "", strict, ignore_unexpected) + return utils.__parse( + conf, clazz, "", strict, ignore_unexpected, globalns, localns + ) except pyparsing.ParseSyntaxException as e: raise MalformedConfigException( f'parsing failure line {e.lineno} character {e.col}, got "{e.line}"' @@ -50,42 +81,52 @@ def file(self, path: str, **kwargs) -> "Multi": conf = ConfigFactory.parse_file(path) return Multi(self.confs + [conf], self.strict, **kwargs) - def on(self, clazz): + @inject_callee_scope + def on(self, clazz, globalns=None, localns=None): conf, *nxts = self.confs for nxt in nxts: conf = ConfigTree.merge_configs(conf, nxt) - return parse(conf, clazz, self.strict, **self.kwargs) + return parse( + conf, clazz, self.strict, globalns=globalns, localns=localns, **self.kwargs + ) multi = Multi([]) -def env(prefix: str, clazz, **kwargs): - return multi.env(prefix, **kwargs).on(clazz) +@inject_callee_scope +def env(prefix: str, clazz, globalns=None, localns=None, **kwargs): + return multi.env(prefix, **kwargs).on(clazz, globalns=globalns, localns=localns) -def dict(obj: str, clazz, **kwargs): - return multi.dict(obj, **kwargs).on(clazz) +@inject_callee_scope +def dict(obj: str, clazz, globalns=None, localns=None, **kwargs): + return multi.dict(obj, **kwargs).on(clazz, globalns=globalns, localns=localns) -def string(s: str, clazz, **kwargs): - return multi.string(s, **kwargs).on(clazz) +@inject_callee_scope +def string(s: str, clazz, globalns=None, localns=None, **kwargs): + return multi.string(s, **kwargs).on(clazz, globalns=globalns, localns=localns) -def url(uri: str, clazz, **kwargs): - return multi.url(uri, **kwargs).on(clazz) +@inject_callee_scope +def url(uri: str, clazz, globalns=None, localns=None, **kwargs): + return multi.url(uri, **kwargs).on(clazz, globalns=globalns, localns=localns) -def file(path: str, clazz, **kwargs): - return multi.file(path, **kwargs).on(clazz) +@inject_callee_scope +def file(path: str, clazz, globalns=None, localns=None, **kwargs): + return multi.file(path, **kwargs).on(clazz, globalns=globalns, localns=localns) -def load(path: str, clazz, **kwargs): - return file(path, clazz, **kwargs) +@inject_callee_scope +def load(path: str, clazz, globalns=None, localns=None, **kwargs): + return file(path, clazz, globalns=globalns, localns=localns, **kwargs) -def loads(s: str, clazz, **kwargs): - return string(s, clazz, **kwargs) +@inject_callee_scope +def loads(s: str, clazz, globalns=None, localns=None, **kwargs): + return string(s, clazz, globalns=globalns, localns=localns, **kwargs) def dump(file: str, instance: object, out: str): diff --git a/dataconf/utils.py b/dataconf/utils.py index cea6272..aecfc24 100644 --- a/dataconf/utils.py +++ b/dataconf/utils.py @@ -3,6 +3,7 @@ from dataclasses import fields from dataclasses import is_dataclass from datetime import datetime +import typing from typing import get_args from typing import get_origin from typing import Union @@ -39,7 +40,7 @@ def is_optional(type): return get_origin(type) is Union and NoneType in get_args(type) -def __parse(value: any, clazz, path, strict, ignore_unexpected): +def __parse(value: any, clazz, path, strict, ignore_unexpected, globalns, localns): if is_dataclass(clazz): @@ -51,6 +52,7 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected): fs = {} renamings = dict() + type_hints = typing.get_type_hints(clazz, globalns, localns) for f in fields(clazz): if f.name in value: @@ -66,10 +68,16 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected): if not isinstance(val, _MISSING_TYPE): fs[f.name] = __parse( - val, f.type, f"{path}.{f.name}", strict, ignore_unexpected + val, + type_hints[f.name], + f"{path}.{f.name}", + strict, + ignore_unexpected, + globalns, + localns, ) - elif is_optional(f.type): + elif is_optional(type_hints[f.name]): # Optional not found fs[f.name] = None @@ -94,7 +102,15 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected): raise MissingTypeException("expected list with type information: List[?]") if value is not None: return [ - __parse(v, args[0], f"{path}[]", strict, ignore_unexpected) + __parse( + v, + args[0], + f"{path}[]", + strict, + ignore_unexpected, + globalns, + localns, + ) for v in value ] return None @@ -106,7 +122,15 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected): ) if value is not None: return { - k: __parse(v, args[1], f"{path}.{k}", strict, ignore_unexpected) + k: __parse( + v, + args[1], + f"{path}.{k}", + strict, + ignore_unexpected, + globalns, + localns, + ) for k, v in value.items() } return None @@ -120,6 +144,8 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected): path, strict, ignore_unexpected, + globalns, + localns, ) except TypeConfigException: # cannot parse Optional @@ -129,10 +155,14 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected): left, right = args try: - return __parse(value, left, path, strict, ignore_unexpected) + return __parse( + value, left, path, strict, ignore_unexpected, globalns, localns + ) except TypeConfigException as left_failure: try: - return __parse(value, right, path, strict, ignore_unexpected) + return __parse( + value, right, path, strict, ignore_unexpected, globalns, localns + ) except TypeConfigException as right_failure: raise TypeConfigException( f"expected type {clazz} at {path}, failed both:\n- {left_failure}\n- {right_failure}" @@ -186,7 +216,15 @@ def __parse(value: any, clazz, path, strict, ignore_unexpected): for child_clazz in sorted(clazz.__subclasses__(), key=lambda c: c.__name__): if is_dataclass(child_clazz): try: - return __parse(value, child_clazz, path, strict, ignore_unexpected) + return __parse( + value, + child_clazz, + path, + strict, + ignore_unexpected, + globalns, + localns, + ) except ( TypeConfigException, MalformedConfigException, diff --git a/tests/test_futur_annotations.py b/tests/test_futur_annotations.py new file mode 100644 index 0000000..19fe2c6 --- /dev/null +++ b/tests/test_futur_annotations.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from dataclasses import dataclass +import os +from typing import get_type_hints +from typing import Text + +import dataconf +from dataconf.main import inject_callee_scope + + +@inject_callee_scope +def out_of_scope_assert(clazz, expected, globalns, localns): + assert get_type_hints(clazz, globalns, localns)["a"] is expected + + +class TestFuturAnnotations: + def test_43(self) -> None: + @dataclass + class Model: + token: str + + os.environ["TEST_token"] = "1" + dataconf.env("TEST_", Model) + + def test_repro(self) -> None: + @dataclass + class A: + value: Text + + @dataclass + class B: + a: A + + out_of_scope_assert(B, A, globalns={})