From 194914fee8fce3abe54dcc881332263faad11e88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sondre=20Lilleb=C3=B8=20Gundersen?= Date: Fri, 25 Jun 2021 12:50:50 +0200 Subject: [PATCH] Register annotations for async function defs --- flake8_type_checking/checker.py | 10 +++++++--- tests/test_tc004.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/flake8_type_checking/checker.py b/flake8_type_checking/checker.py index 3326b58..b6ade54 100644 --- a/flake8_type_checking/checker.py +++ b/flake8_type_checking/checker.py @@ -19,7 +19,6 @@ # to handle and which values to return possible_local_errors += (AppRegistryNotReady, ImproperlyConfigured) # type: ignore - if TYPE_CHECKING: from typing import Any, Generator, Optional, Union @@ -268,7 +267,8 @@ def _add_annotation(self, node: ast.AST) -> None: elif isinstance(node, ast.BinOp): return - def _set_child_node_attribute(self, node: Any, attr: str, val: Any) -> Any: + @staticmethod + def _set_child_node_attribute(node: Any, attr: str, val: Any) -> Any: # Set the parent attribute on the current node children for key, value in node.__dict__.items(): if type(value) not in [int, str, list, bool] and value is not None: @@ -298,7 +298,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None: if getattr(node, 'value', None): self.generic_visit(node.value) # type: ignore - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> None: """Remove and map function arguments and returns.""" # Map annotations for path in [node.args.args, node.args.kwonlyargs]: @@ -331,6 +331,10 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: self.generic_visit(node) + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + """Remove and map function arguments and returns.""" + self.visit_FunctionDef(node) + class TypingOnlyImportsChecker: """Checks for imports exclusively used by type annotation elements.""" diff --git a/tests/test_tc004.py b/tests/test_tc004.py index 4705f66..cf290d6 100644 --- a/tests/test_tc004.py +++ b/tests/test_tc004.py @@ -84,6 +84,25 @@ def example(): ), set(), ), + ( + textwrap.dedent( + """ + from __future__ import annotations + + from typing import TYPE_CHECKING + + if TYPE_CHECKING: + from typing import AsyncIterator, List + + + class Example: + + async def example(self) -> AsyncIterator[List[str]]: + yield 0 + """ + ), + set(), + ), ]