From a6361ea46b4a1ee56797ca33d32e596289d3945e Mon Sep 17 00:00:00 2001 From: Daverball Date: Fri, 13 Dec 2024 12:55:06 +0100 Subject: [PATCH] fix: Avoid false negatives for TC001-003 related to `typing.cast`. --- flake8_type_checking/checker.py | 44 ++++++++++++++++++++++++++++++--- tests/test_tc001_to_tc003.py | 24 ++++++++++++++++++ 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/flake8_type_checking/checker.py b/flake8_type_checking/checker.py index d866926..39ae83c 100644 --- a/flake8_type_checking/checker.py +++ b/flake8_type_checking/checker.py @@ -992,6 +992,31 @@ def visit_annotated_value(self, node: ast.expr) -> None: self.import_visitor.in_soft_use_context = previous_context +class CastTypeExpressionVisitor(AnnotationVisitor): + """Visit a cast type expression and collect all the quoted names.""" + + def __init__(self, typing_lookup: SupportsIsTyping) -> None: + #: All the quoted_names referenced inside the type expression + self.quoted_names: set[str] = set() + self._typing_lookup = typing_lookup + + def is_typing(self, node: ast.AST, symbol: str) -> bool: + """Check if the given node matches the given typing symbol.""" + return self._typing_lookup.is_typing(node, symbol) + + def visit_annotation_name(self, node: ast.Name) -> None: + """Ignore visited names.""" + # We could either record them as quoted names pre-emptively or + # as uses, but neither seems ideal, let's just skip these names + # as we have previously. + + def visit_annotation_string(self, node: ast.Constant) -> None: + """Collect all the names referenced inside the forward reference.""" + visitor = StringAnnotationVisitor(self._typing_lookup) + visitor.parse_and_visit_string_annotation(node.value) + self.quoted_names.update(visitor.names) + + class ImportVisitor( DunderAllMixin, FunctoolsSingledispatchMixin, @@ -1081,6 +1106,10 @@ def __init__( #: Where typing.cast() is called with an unquoted type. self.unquoted_types_in_casts: list[tuple[int, int, str]] = [] + #: All forward referenced names used in cast type expressions + # we need to track this in order to avoid false negatives for TC001-003 + self.quoted_type_names_in_casts: set[str] = set() + #: For tracking which comprehension/IfExp we're currently inside of self.active_context: Comprehension | ast.IfExp | None = None @@ -1895,6 +1924,10 @@ def register_unquoted_type_in_typing_cast(self, node: ast.Call) -> None: arg = node.args[0] + visitor = CastTypeExpressionVisitor(self) + visitor.visit(arg) + self.quoted_type_names_in_casts.update(visitor.quoted_names) + if isinstance(arg, ast.Constant) and isinstance(arg.value, str): return # Type argument is already a string literal. @@ -1999,10 +2032,13 @@ def unused_imports(self) -> Flake8Generator: unused_imports = all_imports - self.visitor.names - self.visitor.soft_uses used_imports = all_imports - unused_imports already_imported_modules = [self.visitor.imports[name].module for name in used_imports] - annotation_names = ( - [n for i in self.visitor.wrapped_annotations for n in i.names] - + [i.annotation for i in self.visitor.unwrapped_annotations] - + [n for i in self.visitor.excess_wrapped_annotations for n in i.names] + annotation_names = list( + chain( + (n for i in self.visitor.wrapped_annotations for n in i.names), + (i.annotation for i in self.visitor.unwrapped_annotations), + (n for i in self.visitor.excess_wrapped_annotations for n in i.names), + self.visitor.quoted_type_names_in_casts, + ) ) for name in unused_imports: diff --git a/tests/test_tc001_to_tc003.py b/tests/test_tc001_to_tc003.py index e0dcd57..ec097b1 100644 --- a/tests/test_tc001_to_tc003.py +++ b/tests/test_tc001_to_tc003.py @@ -323,6 +323,30 @@ def example() -> Any: ), set(), ), + # Issue #127 + ( + textwrap.dedent( + f''' + from {import_} import Foo + from typing import Any, cast + + a = cast('Foo', 1) + ''' + ), + {'2:0 ' + ERROR.format(module=f'{import_}.Foo')}, + ), + # forward reference in sub-expression of cast type + ( + textwrap.dedent( + f''' + from {import_} import Foo + from typing import Any, cast + + a = cast(list['Foo'], 1) + ''' + ), + {'2:0 ' + ERROR.format(module=f'{import_}.Foo')}, + ), ] return [