Skip to content

Commit

Permalink
improve handling of recursive types
Browse files Browse the repository at this point in the history
  • Loading branch information
Aran-Fey committed Feb 2, 2024
1 parent 7f23fdf commit b395478
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 12 deletions.
2 changes: 1 addition & 1 deletion introspection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
New and improved introspection functions
"""

__version__ = "1.7.8"
__version__ = "1.7.9"

from .parameter import *
from .signature_ import *
Expand Down
64 changes: 53 additions & 11 deletions introspection/typing/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def resolve_forward_refs(
mode: Literal["eval", "getattr", "ast"] = "eval",
strict: bool = True,
max_depth: Optional[int] = None,
extra_globals: Mapping[str, object] = {},
treat_name_errors_as_imports: bool = False,
) -> TypeAnnotation:
...
Expand All @@ -61,7 +62,7 @@ def resolve_forward_refs(
...


def resolve_forward_refs( # type: ignore[wtf]
def resolve_forward_refs( # type: ignore
annotation: TypeAnnotation,
context: ForwardRefContext = None,
eval_: Optional[bool] = None,
Expand All @@ -70,7 +71,9 @@ def resolve_forward_refs( # type: ignore[wtf]
module: typing.Optional[types.ModuleType] = None,
mode: Literal["eval", "getattr", "ast"] = "eval",
max_depth: Optional[int] = None,
extra_globals: Mapping[str, object] = {},
treat_name_errors_as_imports: bool = False,
_currently_evaluating: AbstractSet[Tuple[str, int]] = set(),
) -> TypeAnnotation:
"""
Resolves forward references in a type annotation.
Expand Down Expand Up @@ -128,6 +131,8 @@ def recurse(annotation: TypeAnnotation) -> TypeAnnotation:
mode=mode,
strict=strict,
max_depth=max_depth - 1,
extra_globals=extra_globals,
_currently_evaluating=_currently_evaluating, # type: ignore
)

if isinstance(annotation, ForwardRef):
Expand All @@ -137,6 +142,13 @@ def recurse(annotation: TypeAnnotation) -> TypeAnnotation:
annotation = _get_forward_ref_code(annotation)

if isinstance(annotation, str):
# First, check if this exact same forward reference is already being evaluated - i.e. it is
# a recursive type.
key = (annotation, id(context))
if key in _currently_evaluating:
return annotation
_currently_evaluating = _currently_evaluating | {key}

scope: collections.ChainMap[str, object] = collections.ChainMap()

if context is None:
Expand All @@ -155,12 +167,15 @@ def recurse(annotation: TypeAnnotation) -> TypeAnnotation:
scope.maps.append(vars(module))

scope.maps.append(vars(builtins)) # type: ignore
scope.maps.append(extra_globals) # type: ignore

if treat_name_errors_as_imports:
from ._utils import ImporterDict

scope.maps.append(ImporterDict()) # type: ignore

# Note: Annotations can be strings inside of strings, like `"'int'"`. So evaluating it once
# isn't necessarily enough; make sure to call `recurse()` with the result!
if mode == "eval":
try:
# The globals must be a real dict, so the scope will be used as
Expand All @@ -182,17 +197,23 @@ def recurse(annotation: TypeAnnotation) -> TypeAnnotation:
for attr in attrs:
value = getattr(value, attr)

return value # type: ignore
return recurse(value) # type: ignore
except AttributeError:
pass
elif mode == "ast":
expr = ast.parse(annotation, mode="eval")
try:
result = _eval_ast(expr.body, scope, strict=strict, max_depth=max_depth)
result = _eval_ast(
expr.body,
scope,
strict=strict,
max_depth=max_depth,
treat_name_errors_as_imports=treat_name_errors_as_imports,
)
except Exception:
pass
else:
return result # type: ignore
return recurse(result) # type: ignore
else:
assert False, f"Invalid mode: {mode!r}"

Expand Down Expand Up @@ -263,20 +284,22 @@ def recurse(annotation: TypeAnnotation) -> TypeAnnotation:


def _eval_ast(
node: ast.AST, scope: typing.Mapping[str, object], strict: bool, max_depth: int
node: ast.AST,
scope: typing.Mapping[str, object],
strict: bool,
max_depth: int,
treat_name_errors_as_imports: bool,
) -> object:
# Compared to "eval" and "getattr", this method of evaluating forward refs
# has the advantage of being able to perform partial evaluation. For
# example, the forward ref `"ClassVar[NameThatCannotBeResolved]"` can be
# turned into `ClassVar["NameThatCannotBeResolved"]`.
#
# Sometimes we need to know whether the forward ref was resolved or not.
# That's why this function returns a tuple of `(bool, object)`.

def recurse(node: ast.AST) -> object:
if max_depth <= 1:
return ast.unparse(node)

return _eval_ast(node, scope, strict, max_depth - 1)
return _eval_ast(node, scope, strict, max_depth - 1, treat_name_errors_as_imports)

if strict:
safe_recurse = recurse
Expand All @@ -298,15 +321,34 @@ def safe_recurse_if_forwardref(obj: object) -> object:
mode="ast",
strict=strict,
max_depth=max_depth - 1,
treat_name_errors_as_imports=False,
treat_name_errors_as_imports=treat_name_errors_as_imports,
)

if type(node) is ast.Name:
name = node.id
return scope[name]
elif type(node) is ast.Attribute:
obj = recurse(node.value)
return getattr(obj, node.attr)

try:
return getattr(obj, node.attr)
except AttributeError:
# The parameter name is a little misleading, but we'll also treat AttributeErrors on
# modules as missing imports
if not treat_name_errors_as_imports or not isinstance(obj, types.ModuleType):
raise

try:
importlib.import_module(f"{obj.__name__}.{node.attr}")
except ImportError:
pass

try:
return getattr(obj, node.attr)
except AttributeError:
pass

raise
elif type(node) is ast.Subscript:
generic_type = recurse(node.value)
subtype = safe_recurse(node.slice)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_typing/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@


T = TypeVar("T")
RecursiveType = List["RecursiveType"]

THIS_MODULE = sys.modules[__name__]

Expand Down Expand Up @@ -162,6 +163,8 @@ def test_literal_to_string_with_typing(annotation, expected):
("ellipsis", type(...)),
('List["int"]', List[int]),
('Literal["int"]', Literal["int"]), # `Literal` arguments must be left as strings
('"int"', int), # Double stringified
("RecursiveType", RecursiveType),
],
)
@pytest.mark.parametrize("mode", ["eval", "ast"])
Expand Down

0 comments on commit b395478

Please sign in to comment.