Skip to content

Commit

Permalink
Register annotations for async function defs
Browse files Browse the repository at this point in the history
  • Loading branch information
sondrelg committed Jun 25, 2021
1 parent 48e383f commit 194914f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
10 changes: 7 additions & 3 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# to handle and which values to return
possible_local_errors += (AppRegistryNotReady, ImproperlyConfigured) # type: ignore


if TYPE_CHECKING:
from typing import Any, Generator, Optional, Union

Expand Down Expand Up @@ -268,7 +267,8 @@ def _add_annotation(self, node: ast.AST) -> None:
elif isinstance(node, ast.BinOp):
return

def _set_child_node_attribute(self, node: Any, attr: str, val: Any) -> Any:
@staticmethod
def _set_child_node_attribute(node: Any, attr: str, val: Any) -> Any:
# Set the parent attribute on the current node children
for key, value in node.__dict__.items():
if type(value) not in [int, str, list, bool] and value is not None:
Expand Down Expand Up @@ -298,7 +298,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
if getattr(node, 'value', None):
self.generic_visit(node.value) # type: ignore

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
def visit_FunctionDef(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> None:
"""Remove and map function arguments and returns."""
# Map annotations
for path in [node.args.args, node.args.kwonlyargs]:
Expand Down Expand Up @@ -331,6 +331,10 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:

self.generic_visit(node)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
"""Remove and map function arguments and returns."""
self.visit_FunctionDef(node)


class TypingOnlyImportsChecker:
"""Checks for imports exclusively used by type annotation elements."""
Expand Down
19 changes: 19 additions & 0 deletions tests/test_tc004.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,25 @@ def example():
),
set(),
),
(
textwrap.dedent(
"""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import AsyncIterator, List
class Example:
async def example(self) -> AsyncIterator[List[str]]:
yield 0
"""
),
set(),
),
]


Expand Down

0 comments on commit 194914f

Please sign in to comment.