Skip to content

Commit

Permalink
resolve
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 committed Jun 19, 2024
1 parent 7062dc0 commit d25f898
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 13 deletions.
29 changes: 25 additions & 4 deletions src/psygnal/_group_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextlib
import copy
import inspect
import operator
import sys
import warnings
Expand All @@ -13,6 +14,7 @@
Any,
Callable,
ClassVar,
ForwardRef,
Iterable,
Literal,
Mapping,
Expand Down Expand Up @@ -587,10 +589,19 @@ def __set_name__(self, owner: type, name: str) -> None:
def _find_validators(self, owner: type) -> dict[str, list[Validator]]:
validators: dict[str, list[Validator]] = {}
for field, annotation in owner.__annotations__.items():
if get_origin(annotation) is Annotated:
for item in get_args(annotation)[1:]:
if isinstance(item, Validator):
validators.setdefault(field, []).append(item)
try:
annotation = _resolve(annotation, owner)
if get_origin(annotation) is Annotated:
for item in get_args(annotation)[1:]:
if isinstance(item, Validator):
validators.setdefault(field, []).append(item)
except Exception:
warnings.warn(
f"Unable to resolve type annotation {annotation}"
"Psygnal Validator will not work",
stacklevel=2,
)

return validators

def _do_patch_setattr(self, owner: type, with_aliases: bool = True) -> None:
Expand Down Expand Up @@ -725,3 +736,13 @@ def __call__(self, value: Any, *, name: str, owner: Any) -> Any:
f"Error setting value {value!r} for field {name!r} "
f"on type {type(owner)}: {e}"
) from e


def _resolve(annotation: Any, owner: Any) -> Any:
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
if isinstance(annotation, ForwardRef):
guard: frozenset = frozenset()
_globals = inspect.getmodule(owner).__dict__
annotation = annotation._evaluate(_globals, {}, guard)
return annotation
39 changes: 30 additions & 9 deletions tests/test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
from psygnal import Validator, evented


def test_validator():
def _is_positive(value: Any) -> int:
try:
_value = int(value)
except (ValueError, TypeError):
raise ValueError("Value must be an integer") from None
if not _value > 0:
raise ValueError("Value must be positive")
return _value
def _is_positive(value: Any) -> int:
try:
_value = int(value)
except (ValueError, TypeError):
raise ValueError("Value must be an integer") from None
if not _value > 0:
raise ValueError("Value must be positive")
return _value


def test_validator():
@evented
@dataclass
class Foo:
Expand All @@ -27,3 +28,23 @@ class Foo:
assert isinstance(foo.x, int)
with pytest.raises(ValueError):
foo.x = -1


def test_validator_resolution():
@evented
@dataclass
class Bar:
x: "Annotated[int, Validator(_is_positive)]"

with pytest.raises(ValueError, match="Value must be positive"):
Bar(x=-1)

def _local_func(value: Any) -> Any:
return value

with pytest.warns(UserWarning, match="Unable to resolve type"):

@evented
@dataclass
class Baz:
x: "Annotated[int, Validator(_local_func)]"

0 comments on commit d25f898

Please sign in to comment.