diff --git a/flake8_type_checking/checker.py b/flake8_type_checking/checker.py index 1dee8f3..d98e883 100644 --- a/flake8_type_checking/checker.py +++ b/flake8_type_checking/checker.py @@ -38,6 +38,7 @@ builtin_names, sqlalchemy_default_mapped_dotted_names, ) +from flake8_type_checking.util import iter_function_annotation_nodes if TYPE_CHECKING: from _ast import AsyncFunctionDef, FunctionDef @@ -286,24 +287,14 @@ def _function_is_wrapped_by_validate_arguments(self, node: FunctionDef | AsyncFu def visit_FunctionDef(self, node: FunctionDef) -> None: """Remove and map function arguments and returns.""" if self._function_is_wrapped_by_validate_arguments(node): - for path in [node.args.args, node.args.kwonlyargs, node.args.posonlyargs]: - for argument in path: - if hasattr(argument, 'annotation') and argument.annotation: - self.visit(argument.annotation) - - if node.returns: - self.visit(node.returns) + for expr in iter_function_annotation_nodes(node): + self.visit(expr) def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None: """Remove and map function arguments and returns.""" if self._function_is_wrapped_by_validate_arguments(node): - for path in [node.args.args, node.args.kwonlyargs, node.args.posonlyargs]: - for argument in path: - if hasattr(argument, 'annotation') and argument.annotation: - self.visit(argument.annotation) - - if node.returns: - self.visit(node.returns) + for expr in iter_function_annotation_nodes(node): + self.visit(expr) class SQLAlchemyAnnotationVisitor(AnnotationVisitor): @@ -554,13 +545,8 @@ def handle_fastapi_decorator(self, node: AsyncFunctionDef | FunctionDef) -> None To achieve this, we just visit the annotations to register them as "uses". """ - for path in [node.args.args, node.args.kwonlyargs, node.args.posonlyargs]: - for argument in path: - if hasattr(argument, 'annotation') and argument.annotation: - self.visit(argument.annotation) - - if node.returns: - self.visit(node.returns) + for expr in iter_function_annotation_nodes(node): + self.visit(expr) class FunctoolsSingledispatchMixin: @@ -600,13 +586,15 @@ def visit_FunctionDef(self, node: FunctionDef) -> None: """Remove and map function arguments and returns.""" super().visit_FunctionDef(node) # type: ignore[misc] if self.has_singledispatch_decorator(node): - self.handle_singledispatch_decorator(node) + for expr in iter_function_annotation_nodes(node): + self.visit(expr) def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None: """Remove and map function arguments and returns.""" super().visit_AsyncFunctionDef(node) # type: ignore[misc] if self.has_singledispatch_decorator(node): - self.handle_singledispatch_decorator(node) + for expr in iter_function_annotation_nodes(node): + self.visit(expr) def has_singledispatch_decorator(self, node: FunctionDef | AsyncFunctionDef) -> bool: """Determine whether this function is decorated with `functools.singledispatch`.""" @@ -615,16 +603,6 @@ def has_singledispatch_decorator(self, node: FunctionDef | AsyncFunctionDef) -> for decorator_node in node.decorator_list ) - def handle_singledispatch_decorator(self, node: FunctionDef | AsyncFunctionDef) -> None: - """Walk all the annotations to register them as runtime uses.""" - for path in [node.args.args, node.args.kwonlyargs, node.args.posonlyargs]: - for argument in path: - if hasattr(argument, 'annotation') and argument.annotation: - self.visit(argument.annotation) - - if node.returns: - self.visit(node.returns) - @dataclass class ImportName: diff --git a/flake8_type_checking/util.py b/flake8_type_checking/util.py new file mode 100644 index 0000000..ad00c43 --- /dev/null +++ b/flake8_type_checking/util.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from itertools import chain +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import ast + from ast import AsyncFunctionDef, FunctionDef + from collections.abc import Iterator + + +def iter_function_annotation_nodes(node: AsyncFunctionDef | FunctionDef) -> Iterator[ast.expr]: + """Yield all the annotation expression nodes inside the given function node.""" + for arg in chain(node.args.args, node.args.kwonlyargs, node.args.posonlyargs): + if arg.annotation: + yield arg.annotation + + for opt_arg in (node.args.kwarg, node.args.vararg): + if opt_arg and opt_arg.annotation: + yield opt_arg.annotation + + if node.returns: + yield node.returns