From 95bf2e08991f85bbeb2785445845f0b8e07d0473 Mon Sep 17 00:00:00 2001 From: Abel Cheung Date: Sun, 20 Oct 2024 13:28:58 +0000 Subject: [PATCH] feat(test): Move subscript transform to each type checker adapter With subscript transformation handled by AST node transformer, it is now possible to handle nested subscript in type expression --- test-rt/_testutils/common.py | 26 +++++++++++++++++++-- test-rt/_testutils/mypy_adapter.py | 14 ++++++------ test-rt/_testutils/pyright_adapter.py | 9 ++++---- test-rt/_testutils/rt_wrapper.py | 33 ++++++++------------------- 4 files changed, 44 insertions(+), 38 deletions(-) diff --git a/test-rt/_testutils/common.py b/test-rt/_testutils/common.py index a0cb199..061adda 100644 --- a/test-rt/_testutils/common.py +++ b/test-rt/_testutils/common.py @@ -3,7 +3,7 @@ import importlib import pathlib import re -from typing import Any, ClassVar, ForwardRef, Iterable, NamedTuple +from typing import Any, ClassVar, ForwardRef, Iterable, NamedTuple, cast from lxml.etree import LXML_VERSION @@ -39,7 +39,7 @@ def __str__(self) -> str: return "{}(): {}".format(self._func, self.args[0]) -class NameCollectorBase(ast.NodeVisitor): +class NameCollectorBase(ast.NodeTransformer): def __init__( self, globalns: dict[str, Any], @@ -49,11 +49,33 @@ def __init__( self._globalns = globalns self._localns = localns self.modified: bool = False + # typing_extensions guaranteed to be present, + # as a dependency of typeguard self.collected: dict[str, Any] = { m: importlib.import_module(m) for m in ("builtins", "typing", "typing_extensions") } + def visit_Subscript(self, node: ast.Subscript) -> ast.expr: + node.value = cast("ast.expr", self.visit(node.value)) + node.slice = cast("ast.expr", self.visit(node.slice)) + + # When type reference is a stub-only specialized class + # which don't have runtime support (lxml classes have + # no __class_getitem__), concede by verifying + # unsubscripted type. + try: + eval(ast.unparse(node), self._globalns, self._localns | self.collected) + except TypeError as e: + if "is not subscriptable" not in e.args[0]: + raise + # TODO Insert node.value dependent hook for extra + # varification of subscript type + self.modified = True + return node.value + else: + return node + class TypeCheckerAdapterBase: id: ClassVar[str] diff --git a/test-rt/_testutils/mypy_adapter.py b/test-rt/_testutils/mypy_adapter.py index d5278b0..a910bf2 100644 --- a/test-rt/_testutils/mypy_adapter.py +++ b/test-rt/_testutils/mypy_adapter.py @@ -25,8 +25,8 @@ class _MypyDiagObj(_t.TypedDict): message: str -class _NameCollector(NameCollectorBase, ast.NodeTransformer): - def visit_Attribute(self, node: ast.Attribute) -> ast.AST: +class _NameCollector(NameCollectorBase): + def visit_Attribute(self, node: ast.Attribute) -> ast.expr: prefix = ast.unparse(node.value) name = node.attr @@ -49,9 +49,9 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST: self.modified = True return ast.Name(id=name, ctx=node.ctx) - node = _t.cast("ast.Attribute", self.generic_visit(node)) + _ = self.visit(node.value) - if (resolved := getattr(self.collected[prefix], name, False)): + if resolved := getattr(self.collected[prefix], name, False): self.collected[ast.unparse(node)] = resolved return node @@ -70,7 +70,7 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST: # but with a few exceptions (like tuple, Union). # visit_Attribute can ultimately recurse into visit_Name # as well - def visit_Name(self, node: ast.Name) -> ast.AST: + def visit_Name(self, node: ast.Name) -> ast.Name: name = node.id try: eval(name, self._globalns, self._localns | self.collected) @@ -96,10 +96,10 @@ def visit_Name(self, node: ast.Name) -> ast.AST: # For class defined inside local function scope, mypy outputs # something like "test_elem_class_lookup.FooClass@97". # Return only the left operand after processing. - def visit_BinOp(self, node: ast.BinOp) -> ast.AST: + def visit_BinOp(self, node: ast.BinOp) -> ast.expr: if isinstance(node.op, ast.MatMult) and isinstance(node.right, ast.Constant): # Mypy disallows returning Any - return _t.cast("ast.AST", self.visit(node.left)) + return _t.cast("ast.expr", self.visit(node.left)) # For expression that haven't been accounted for, just don't # process and allow name resolution to fail return node diff --git a/test-rt/_testutils/pyright_adapter.py b/test-rt/_testutils/pyright_adapter.py index 4b89722..83f949a 100644 --- a/test-rt/_testutils/pyright_adapter.py +++ b/test-rt/_testutils/pyright_adapter.py @@ -12,7 +12,7 @@ class _NameCollector(NameCollectorBase): # Pyright inferred type results always contain bare names only, # so don't need to bother with visit_Attribute() - def visit_Name(self, node: ast.Name) -> None: + def visit_Name(self, node: ast.Name) -> ast.Name: name = node.id try: eval(name, self._globalns, self._localns | self.collected) @@ -20,10 +20,9 @@ def visit_Name(self, node: ast.Name) -> None: for m in ("typing", "typing_extensions"): if hasattr(self.collected[m], name): self.collected[name] = getattr(self.collected[m], name) - break - else: - raise - + return node + raise + return node class _TypeCheckerAdapter(TypeCheckerAdapterBase): id = "pyright" diff --git a/test-rt/_testutils/rt_wrapper.py b/test-rt/_testutils/rt_wrapper.py index 6150ddc..41ad423 100644 --- a/test-rt/_testutils/rt_wrapper.py +++ b/test-rt/_testutils/rt_wrapper.py @@ -105,37 +105,22 @@ def reveal_type_wrapper(var: _T) -> _T: adapter.typechecker_result[pos] = VarType(var_name, tc_result.type) ref = tc_result.type - ref_ast = ast.parse(ref.__forward_arg__, mode="eval") - walker = adapter.create_collector(globalns, localns) - if isinstance(walker, ast.NodeTransformer): + try: + _ = eval(ref.__forward_arg__, globalns, localns) + except: + ref_ast = ast.parse(ref.__forward_arg__, mode="eval") + walker = adapter.create_collector(globalns, localns) new_ast = walker.visit(ref_ast) if walker.modified: - ref_ast = ast.fix_missing_locations(new_ast) - ref = _t.ForwardRef(ast.unparse(ref_ast)) + ref = _t.ForwardRef(ast.unparse(new_ast)) + memo = TypeCheckMemo(globalns, localns | walker.collected) else: - walker.visit(ref_ast) - memo = TypeCheckMemo(globalns, localns | walker.collected) + memo = TypeCheckMemo(globalns, localns) + try: check_type_internal(var, ref, memo) except TypeCheckError as e: e.args = (f"({adapter.id}) " + e.args[0],) + e.args[1:] raise - except TypeError as e: - if "is not subscriptable" not in e.args[0]: - raise - assert isinstance(ref_ast.body, ast.Subscript) - # When type reference is a specialized class, we - # have to concede by verifying unsubscripted type, - # as specialized class is a stub-only thing here. - # Lxml runtime does not support __class_getitem__ - # - # FIXME: Only the simplest, unnested subscript supported. - # Need some work for more complex ones. - bare_type = ast.unparse(ref_ast.body.value) - try: - check_type_internal(var, _t.ForwardRef(bare_type), memo) - except TypeCheckError as e: - e.args = (f"({adapter.id}) " + e.args[0],) + e.args[1:] - raise return var