diff --git a/pyproject.toml b/pyproject.toml index 6445b34..cf425db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dependencies = [ - "Sphinx >= 5.3.0", + "Sphinx >= 7", "Unidecode >= 1.3.6", "black", ] diff --git a/tests/test_objects_get_vars.py b/tests/test_objects_get_vars.py index faacaae..7c22286 100644 --- a/tests/test_objects_get_vars.py +++ b/tests/test_objects_get_vars.py @@ -1,5 +1,8 @@ import collections import enum +from typing import Generic, TypeVar + +import pytest import uqbar.objects @@ -20,19 +23,29 @@ class Enumeration(enum.IntEnum): BAZ = 3 -def test_objects_get_vars_01(): - my_object = MyObject("a", "b", "c", "d", foo="x", quux=["y", "z"]) - assert uqbar.objects.get_vars(my_object) == ( - collections.OrderedDict([("arg1", "a"), ("arg2", "b")]), - ["c", "d"], - {"bar": None, "foo": "x", "quux": ["y", "z"]}, - ) - - -def test_objects_get_vars_02(): - my_object = Enumeration.BAZ - assert uqbar.objects.get_vars(my_object) == ( - collections.OrderedDict([("value", 3)]), - [], - {}, - ) +T = TypeVar("T") + + +class MyGeneric(Generic[T]): + def __init__(self, arg: T) -> None: + self.arg = arg + + +@pytest.mark.parametrize( + "object_, expected", + [ + ( + MyObject("a", "b", "c", "d", foo="x", quux=["y", "z"]), + ( + collections.OrderedDict([("arg1", "a"), ("arg2", "b")]), + ["c", "d"], + {"bar": None, "foo": "x", "quux": ["y", "z"]}, + ), + ), + (Enumeration.BAZ, (collections.OrderedDict([("value", 3)]), [], {})), + (MyGeneric[int](arg=3), (collections.OrderedDict([("arg", 3)]), [], {})), + ], +) +def test_objects_get_vars(object_, expected): + actual = uqbar.objects.get_vars(object_) + assert actual == expected diff --git a/uqbar/_version.py b/uqbar/_version.py index 09e40b5..7ed45ff 100644 --- a/uqbar/_version.py +++ b/uqbar/_version.py @@ -1,2 +1,2 @@ -__version_info__ = (0, 7, 2) +__version_info__ = (0, 7, 3) __version__ = ".".join(str(x) for x in __version_info__) diff --git a/uqbar/objects.py b/uqbar/objects.py index 08f2702..1005001 100644 --- a/uqbar/objects.py +++ b/uqbar/objects.py @@ -1,5 +1,8 @@ import collections import inspect +from typing import Any, Dict, Optional, TypeVar + +T = TypeVar("T") def _dispatch_formatting(expr): @@ -9,21 +12,7 @@ def _dispatch_formatting(expr): def _get_object_signature(expr): - expr = type(expr) - # print('E-I-ID', id(expr.__init__)) - # print('E-N-ID', id(expr.__new__)) - # print('o-I-ID', id(object.__init__)) - # print('o-N-ID', id(object.__new__)) - # print('IEQ?', expr.__init__ == object.__init__) - # print('NEQ?', expr.__new__ == object.__new__) - # attrs = {_.name: _ for _ in inspect.classify_class_attrs(expr)} - # print('I?', attrs['__init__']) - # print('N?', attrs['__new__']) - if expr.__new__ is not object.__new__: - return inspect.signature(expr.__new__) - if expr.__init__ is not object.__init__: - return inspect.signature(expr.__init__) - return None + return inspect.signature(type(expr).__init__) def _get_sequence_repr(expr): @@ -84,7 +73,9 @@ def get_hash(expr): return hash(tuple(hash_values)) -def get_repr(expr, multiline=None, suppress_defaults=True): +def get_repr( + expr, multiline: Optional[bool] = None, suppress_defaults: bool = True +) -> str: """ Build a repr string for ``expr`` from its vars and signature. @@ -166,11 +157,9 @@ def get_repr(expr, multiline=None, suppress_defaults=True): for i, part in enumerate(parts): parts[i] = "\n".join(" " + line for line in part.split("\n")) parts.append(")") - parts = ",\n".join(parts) - return "{}(\n{}".format(type(expr).__name__, parts) + return "{}(\n{}".format(type(expr).__name__, ",\n".join(parts)) - parts = ", ".join(parts) - return "{}({})".format(type(expr).__name__, parts) + return "{}({})".format(type(expr).__name__, ", ".join(parts)) def get_vars(expr): @@ -211,88 +200,87 @@ def get_vars(expr): {'foo': 'x', 'bar': None, 'quux': ['y', 'z']} """ - # print('TYPE?', type(expr)) - signature = _get_object_signature(expr) - if signature is None: - return ({}, [], {}) - # print('SIG?', signature) - args = collections.OrderedDict() - var_args = [] - kwargs = {} - if expr is None: - return args, var_args, kwargs - for i, (name, parameter) in enumerate(signature.parameters.items()): - # print(' ', parameter, parameter.kind, parameter.default) - - if i == 0 and name in ("self", "cls", "class_", "klass"): - continue - - if parameter.kind is inspect._POSITIONAL_ONLY: - try: - args[name] = getattr(expr, name) - except AttributeError: - args[name] = expr[name] - elif parameter.kind in (inspect._POSITIONAL_OR_KEYWORD, inspect._KEYWORD_ONLY): - for x in (name, "_" + name): + def _get_vars(signature): + args = collections.OrderedDict() + var_args = [] + kwargs = {} + if expr is None: + return args, var_args, kwargs + for i, (name, parameter) in enumerate(signature.parameters.items()): + # print(' ', parameter, parameter.kind, parameter.default) + if i == 0 and name in ("self", "cls", "class_", "klass"): + continue + if parameter.kind is inspect._POSITIONAL_ONLY: try: - value = getattr(expr, x) - break + args[name] = getattr(expr, name) except AttributeError: + args[name] = expr[name] + elif parameter.kind in ( + inspect._POSITIONAL_OR_KEYWORD, + inspect._KEYWORD_ONLY, + ): + for x in (name, "_" + name): try: - value = expr[x] + value = getattr(expr, x) break - except (KeyError, TypeError): - pass - else: - raise ValueError("Cannot find value for {!r}".format(name)) - # print(" ", value) - if parameter.kind is inspect._KEYWORD_ONLY: - # print(" ??? A") + except AttributeError: + try: + value = expr[x] + break + except (KeyError, TypeError): + pass + else: + raise ValueError("Cannot find value for {!r}".format(name)) + # print(" ", value) + if parameter.kind is inspect._KEYWORD_ONLY: + # print(" ??? A") + kwargs[name] = value + continue + elif parameter.default is inspect._empty: + # print(" ??? B") + args[name] = value + continue + elif parameter.default == value: + # print(" ??? C") + kwargs[name] = value + continue + # print(" ??? D") kwargs[name] = value - continue - elif parameter.default is inspect._empty: - # print(" ??? B") - args[name] = value - continue - elif parameter.default == value: - # print(" ??? C") - kwargs[name] = value - continue - # print(" ??? D") - kwargs[name] = value - - elif parameter.kind is inspect._VAR_POSITIONAL: - value = None - try: - value = expr[:] - except TypeError: - value = getattr(expr, name) - if value: - var_args.extend(value) - - elif parameter.kind is inspect._VAR_KEYWORD: - items = {} - if hasattr(expr, "items"): - items = expr.items() - elif hasattr(expr, name): - mapping = getattr(expr, name) - if not isinstance(mapping, dict): - mapping = dict(mapping) - items = mapping.items() - elif hasattr(expr, "_" + name): - mapping = getattr(expr, "_" + name) - if not isinstance(mapping, dict): - mapping = dict(mapping) - items = mapping.items() - for key, value in items: - if key not in args: - kwargs[key] = value - - return args, var_args, kwargs - - -def new(expr, *args, **kwargs): + elif parameter.kind is inspect._VAR_POSITIONAL: + value = None + try: + value = expr[:] + except TypeError: + value = getattr(expr, name) + if value: + var_args.extend(value) + elif parameter.kind is inspect._VAR_KEYWORD: + items = {} + if hasattr(expr, "items"): + items = expr.items() + elif hasattr(expr, name): + mapping = getattr(expr, name) + if not isinstance(mapping, dict): + mapping = dict(mapping) + items = mapping.items() + elif hasattr(expr, "_" + name): + mapping = getattr(expr, "_" + name) + if not isinstance(mapping, dict): + mapping = dict(mapping) + items = mapping.items() + for key, value in items: + if key not in args: + kwargs[key] = value + return args, var_args, kwargs + + try: + return _get_vars(inspect.signature(type(expr).__init__)) + except AttributeError: + return _get_vars(inspect.signature(type(expr).__new__)) + + +def new(expr: T, *args, **kwargs) -> T: """ Template an object. @@ -343,7 +331,7 @@ def new(expr, *args, **kwargs): current_args, current_var_args, current_kwargs = get_vars(expr) new_kwargs = current_kwargs.copy() - recursive_arguments = {} + recursive_arguments: Dict[str, Any] = {} for key in tuple(kwargs): if "__" in key: value = kwargs.pop(key) diff --git a/uqbar/sphinx/inheritance.py b/uqbar/sphinx/inheritance.py index bb67cee..0b5e52e 100644 --- a/uqbar/sphinx/inheritance.py +++ b/uqbar/sphinx/inheritance.py @@ -33,7 +33,7 @@ import os import pathlib import subprocess -from typing import Any, Dict, List, Mapping, cast +from typing import Any, Dict, List, Mapping, Optional, cast from docutils.nodes import Element, General, Node, SkipNode from docutils.parsers.rst import Directive, directives @@ -114,7 +114,7 @@ def build_urls(self: HTMLTranslator, node: inheritance_diagram) -> Mapping[str, if not isinstance(child, Element): continue # Another document - refuri: str | None + refuri: Optional[str] if (refuri := child.attributes.get("refuri")) is not None: package_path: str = child["reftitle"] if refuri.startswith("http"):