Skip to content

Commit

Permalink
fix: Avoid false negatives for TC001-003 related to typing.cast.
Browse files Browse the repository at this point in the history
  • Loading branch information
Daverball authored Dec 13, 2024
1 parent eef57df commit 87543b7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
44 changes: 40 additions & 4 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,31 @@ def visit_annotated_value(self, node: ast.expr) -> None:
self.import_visitor.in_soft_use_context = previous_context


class CastTypeExpressionVisitor(AnnotationVisitor):
"""Visit a cast type expression and collect all the quoted names."""

def __init__(self, typing_lookup: SupportsIsTyping) -> None:
#: All the quoted_names referenced inside the type expression
self.quoted_names: set[str] = set()
self._typing_lookup = typing_lookup

def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""
return self._typing_lookup.is_typing(node, symbol)

def visit_annotation_name(self, node: ast.Name) -> None:
"""Ignore visited names."""
# We could either record them as quoted names pre-emptively or
# as uses, but neither seems ideal, let's just skip these names
# as we have previously.

def visit_annotation_string(self, node: ast.Constant) -> None:
"""Collect all the names referenced inside the forward reference."""
visitor = StringAnnotationVisitor(self._typing_lookup)
visitor.parse_and_visit_string_annotation(node.value)
self.quoted_names.update(visitor.names)


class ImportVisitor(
DunderAllMixin,
FunctoolsSingledispatchMixin,
Expand Down Expand Up @@ -1081,6 +1106,10 @@ def __init__(
#: Where typing.cast() is called with an unquoted type.
self.unquoted_types_in_casts: list[tuple[int, int, str]] = []

#: All forward referenced names used in cast type expressions
# we need to track this in order to avoid false negatives for TC001-003
self.quoted_type_names_in_casts: set[str] = set()

#: For tracking which comprehension/IfExp we're currently inside of
self.active_context: Comprehension | ast.IfExp | None = None

Expand Down Expand Up @@ -1895,6 +1924,10 @@ def register_unquoted_type_in_typing_cast(self, node: ast.Call) -> None:

arg = node.args[0]

visitor = CastTypeExpressionVisitor(self)
visitor.visit(arg)
self.quoted_type_names_in_casts.update(visitor.quoted_names)

if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
return # Type argument is already a string literal.

Expand Down Expand Up @@ -1999,10 +2032,13 @@ def unused_imports(self) -> Flake8Generator:
unused_imports = all_imports - self.visitor.names - self.visitor.soft_uses
used_imports = all_imports - unused_imports
already_imported_modules = [self.visitor.imports[name].module for name in used_imports]
annotation_names = (
[n for i in self.visitor.wrapped_annotations for n in i.names]
+ [i.annotation for i in self.visitor.unwrapped_annotations]
+ [n for i in self.visitor.excess_wrapped_annotations for n in i.names]
annotation_names = list(
chain(
(n for i in self.visitor.wrapped_annotations for n in i.names),
(i.annotation for i in self.visitor.unwrapped_annotations),
(n for i in self.visitor.excess_wrapped_annotations for n in i.names),
self.visitor.quoted_type_names_in_casts,
)
)

for name in unused_imports:
Expand Down
24 changes: 24 additions & 0 deletions tests/test_tc001_to_tc003.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,30 @@ def example() -> Any:
),
set(),
),
# Issue #127
(
textwrap.dedent(
f'''
from {import_} import Foo
from typing import Any, cast
a = cast('Foo', 1)
'''
),
{'2:0 ' + ERROR.format(module=f'{import_}.Foo')},
),
# forward reference in sub-expression of cast type
(
textwrap.dedent(
f'''
from {import_} import Foo
from typing import Any, cast
a = cast(list['Foo'], 1)
'''
),
{'2:0 ' + ERROR.format(module=f'{import_}.Foo')},
),
]

return [
Expand Down

0 comments on commit 87543b7

Please sign in to comment.