Skip to content

Commit

Permalink
feat(test): Move subscript transform to each type checker adapter
Browse files Browse the repository at this point in the history
With subscript transformation handled by AST node transformer, it is now possible to handle nested subscript in type expression
  • Loading branch information
abelcheung committed Oct 20, 2024
1 parent 7ba3222 commit 95bf2e0
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 38 deletions.
26 changes: 24 additions & 2 deletions test-rt/_testutils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import importlib
import pathlib
import re
from typing import Any, ClassVar, ForwardRef, Iterable, NamedTuple
from typing import Any, ClassVar, ForwardRef, Iterable, NamedTuple, cast

from lxml.etree import LXML_VERSION

Expand Down Expand Up @@ -39,7 +39,7 @@ def __str__(self) -> str:
return "{}(): {}".format(self._func, self.args[0])


class NameCollectorBase(ast.NodeVisitor):
class NameCollectorBase(ast.NodeTransformer):
def __init__(
self,
globalns: dict[str, Any],
Expand All @@ -49,11 +49,33 @@ def __init__(
self._globalns = globalns
self._localns = localns
self.modified: bool = False
# typing_extensions guaranteed to be present,
# as a dependency of typeguard
self.collected: dict[str, Any] = {
m: importlib.import_module(m)
for m in ("builtins", "typing", "typing_extensions")
}

def visit_Subscript(self, node: ast.Subscript) -> ast.expr:
node.value = cast("ast.expr", self.visit(node.value))
node.slice = cast("ast.expr", self.visit(node.slice))

# When type reference is a stub-only specialized class
# which don't have runtime support (lxml classes have
# no __class_getitem__), concede by verifying
# unsubscripted type.
try:
eval(ast.unparse(node), self._globalns, self._localns | self.collected)
except TypeError as e:
if "is not subscriptable" not in e.args[0]:
raise
# TODO Insert node.value dependent hook for extra
# varification of subscript type
self.modified = True
return node.value
else:
return node


class TypeCheckerAdapterBase:
id: ClassVar[str]
Expand Down
14 changes: 7 additions & 7 deletions test-rt/_testutils/mypy_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class _MypyDiagObj(_t.TypedDict):
message: str


class _NameCollector(NameCollectorBase, ast.NodeTransformer):
def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
class _NameCollector(NameCollectorBase):
def visit_Attribute(self, node: ast.Attribute) -> ast.expr:
prefix = ast.unparse(node.value)
name = node.attr

Expand All @@ -49,9 +49,9 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
self.modified = True
return ast.Name(id=name, ctx=node.ctx)

node = _t.cast("ast.Attribute", self.generic_visit(node))
_ = self.visit(node.value)

if (resolved := getattr(self.collected[prefix], name, False)):
if resolved := getattr(self.collected[prefix], name, False):
self.collected[ast.unparse(node)] = resolved
return node

Expand All @@ -70,7 +70,7 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
# but with a few exceptions (like tuple, Union).
# visit_Attribute can ultimately recurse into visit_Name
# as well
def visit_Name(self, node: ast.Name) -> ast.AST:
def visit_Name(self, node: ast.Name) -> ast.Name:
name = node.id
try:
eval(name, self._globalns, self._localns | self.collected)
Expand All @@ -96,10 +96,10 @@ def visit_Name(self, node: ast.Name) -> ast.AST:
# For class defined inside local function scope, mypy outputs
# something like "test_elem_class_lookup.FooClass@97".
# Return only the left operand after processing.
def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
def visit_BinOp(self, node: ast.BinOp) -> ast.expr:
if isinstance(node.op, ast.MatMult) and isinstance(node.right, ast.Constant):
# Mypy disallows returning Any
return _t.cast("ast.AST", self.visit(node.left))
return _t.cast("ast.expr", self.visit(node.left))
# For expression that haven't been accounted for, just don't
# process and allow name resolution to fail
return node
Expand Down
9 changes: 4 additions & 5 deletions test-rt/_testutils/pyright_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,17 @@
class _NameCollector(NameCollectorBase):
# Pyright inferred type results always contain bare names only,
# so don't need to bother with visit_Attribute()
def visit_Name(self, node: ast.Name) -> None:
def visit_Name(self, node: ast.Name) -> ast.Name:
name = node.id
try:
eval(name, self._globalns, self._localns | self.collected)
except NameError:
for m in ("typing", "typing_extensions"):
if hasattr(self.collected[m], name):
self.collected[name] = getattr(self.collected[m], name)
break
else:
raise

return node
raise
return node

class _TypeCheckerAdapter(TypeCheckerAdapterBase):
id = "pyright"
Expand Down
33 changes: 9 additions & 24 deletions test-rt/_testutils/rt_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,37 +105,22 @@ def reveal_type_wrapper(var: _T) -> _T:
adapter.typechecker_result[pos] = VarType(var_name, tc_result.type)

ref = tc_result.type
ref_ast = ast.parse(ref.__forward_arg__, mode="eval")
walker = adapter.create_collector(globalns, localns)
if isinstance(walker, ast.NodeTransformer):
try:
_ = eval(ref.__forward_arg__, globalns, localns)
except:
ref_ast = ast.parse(ref.__forward_arg__, mode="eval")
walker = adapter.create_collector(globalns, localns)
new_ast = walker.visit(ref_ast)
if walker.modified:
ref_ast = ast.fix_missing_locations(new_ast)
ref = _t.ForwardRef(ast.unparse(ref_ast))
ref = _t.ForwardRef(ast.unparse(new_ast))
memo = TypeCheckMemo(globalns, localns | walker.collected)
else:
walker.visit(ref_ast)
memo = TypeCheckMemo(globalns, localns | walker.collected)
memo = TypeCheckMemo(globalns, localns)

try:
check_type_internal(var, ref, memo)
except TypeCheckError as e:
e.args = (f"({adapter.id}) " + e.args[0],) + e.args[1:]
raise
except TypeError as e:
if "is not subscriptable" not in e.args[0]:
raise
assert isinstance(ref_ast.body, ast.Subscript)
# When type reference is a specialized class, we
# have to concede by verifying unsubscripted type,
# as specialized class is a stub-only thing here.
# Lxml runtime does not support __class_getitem__
#
# FIXME: Only the simplest, unnested subscript supported.
# Need some work for more complex ones.
bare_type = ast.unparse(ref_ast.body.value)
try:
check_type_internal(var, _t.ForwardRef(bare_type), memo)
except TypeCheckError as e:
e.args = (f"({adapter.id}) " + e.args[0],) + e.args[1:]
raise

return var

0 comments on commit 95bf2e0

Please sign in to comment.