Skip to content

Commit

Permalink
refactor: Move common function annotation walking into utility function
Browse files Browse the repository at this point in the history
  • Loading branch information
Daverball committed Dec 13, 2024
1 parent b3c2c4b commit 7ce3b36
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 33 deletions.
44 changes: 11 additions & 33 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
builtin_names,
sqlalchemy_default_mapped_dotted_names,
)
from flake8_type_checking.util import iter_function_annotation_nodes

if TYPE_CHECKING:
from _ast import AsyncFunctionDef, FunctionDef
Expand Down Expand Up @@ -286,24 +287,14 @@ def _function_is_wrapped_by_validate_arguments(self, node: FunctionDef | AsyncFu
def visit_FunctionDef(self, node: FunctionDef) -> None:
"""Remove and map function arguments and returns."""
if self._function_is_wrapped_by_validate_arguments(node):
for path in [node.args.args, node.args.kwonlyargs, node.args.posonlyargs]:
for argument in path:
if hasattr(argument, 'annotation') and argument.annotation:
self.visit(argument.annotation)

if node.returns:
self.visit(node.returns)
for expr in iter_function_annotation_nodes(node):
self.visit(expr)

def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
"""Remove and map function arguments and returns."""
if self._function_is_wrapped_by_validate_arguments(node):
for path in [node.args.args, node.args.kwonlyargs, node.args.posonlyargs]:
for argument in path:
if hasattr(argument, 'annotation') and argument.annotation:
self.visit(argument.annotation)

if node.returns:
self.visit(node.returns)
for expr in iter_function_annotation_nodes(node):
self.visit(expr)


class SQLAlchemyAnnotationVisitor(AnnotationVisitor):
Expand Down Expand Up @@ -554,13 +545,8 @@ def handle_fastapi_decorator(self, node: AsyncFunctionDef | FunctionDef) -> None
To achieve this, we just visit the annotations to register them as "uses".
"""
for path in [node.args.args, node.args.kwonlyargs, node.args.posonlyargs]:
for argument in path:
if hasattr(argument, 'annotation') and argument.annotation:
self.visit(argument.annotation)

if node.returns:
self.visit(node.returns)
for expr in iter_function_annotation_nodes(node):
self.visit(expr)


class FunctoolsSingledispatchMixin:
Expand Down Expand Up @@ -600,13 +586,15 @@ def visit_FunctionDef(self, node: FunctionDef) -> None:
"""Remove and map function arguments and returns."""
super().visit_FunctionDef(node) # type: ignore[misc]
if self.has_singledispatch_decorator(node):
self.handle_singledispatch_decorator(node)
for expr in iter_function_annotation_nodes(node):
self.visit(expr)

def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
"""Remove and map function arguments and returns."""
super().visit_AsyncFunctionDef(node) # type: ignore[misc]
if self.has_singledispatch_decorator(node):
self.handle_singledispatch_decorator(node)
for expr in iter_function_annotation_nodes(node):
self.visit(expr)

def has_singledispatch_decorator(self, node: FunctionDef | AsyncFunctionDef) -> bool:
"""Determine whether this function is decorated with `functools.singledispatch`."""
Expand All @@ -615,16 +603,6 @@ def has_singledispatch_decorator(self, node: FunctionDef | AsyncFunctionDef) ->
for decorator_node in node.decorator_list
)

def handle_singledispatch_decorator(self, node: FunctionDef | AsyncFunctionDef) -> None:
"""Walk all the annotations to register them as runtime uses."""
for path in [node.args.args, node.args.kwonlyargs, node.args.posonlyargs]:
for argument in path:
if hasattr(argument, 'annotation') and argument.annotation:
self.visit(argument.annotation)

if node.returns:
self.visit(node.returns)


@dataclass
class ImportName:
Expand Down
23 changes: 23 additions & 0 deletions flake8_type_checking/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

from itertools import chain
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import ast
from ast import AsyncFunctionDef, FunctionDef
from collections.abc import Iterator


def iter_function_annotation_nodes(node: AsyncFunctionDef | FunctionDef) -> Iterator[ast.expr]:
"""Yield all the annotation expression nodes inside the given function node."""
for arg in chain(node.args.args, node.args.kwonlyargs, node.args.posonlyargs):
if arg.annotation:
yield arg.annotation

for opt_arg in (node.args.kwarg, node.args.vararg):
if opt_arg and opt_arg.annotation:
yield opt_arg.annotation

if node.returns:
yield node.returns

0 comments on commit 7ce3b36

Please sign in to comment.