From 3a3650615483dd058a4870acdaf72f5fc8dd0b7c Mon Sep 17 00:00:00 2001 From: Daverball Date: Mon, 9 Dec 2024 08:54:16 +0100 Subject: [PATCH 1/3] refactor: Make use of `is_typing` for `Literal`/`Annotated` detection. Makes soft use concept of `mapped_names` more generic and reuse it for walking `Annotated` value expressions. --- flake8_type_checking/checker.py | 99 +++++++++++++++++++++++++-------- flake8_type_checking/types.py | 4 ++ tests/test_name_extraction.py | 54 +++++++++++------- tests/test_sqlalchemy.py | 19 +++++++ tests/test_tc001_to_tc003.py | 6 +- 5 files changed, 138 insertions(+), 44 deletions(-) diff --git a/flake8_type_checking/checker.py b/flake8_type_checking/checker.py index aceabd9..7eb1d15 100644 --- a/flake8_type_checking/checker.py +++ b/flake8_type_checking/checker.py @@ -66,12 +66,17 @@ def ast_unparse(node: ast.AST) -> str: Import, ImportTypeValue, Name, + SupportsIsTyping, ) class AnnotationVisitor(ABC): """Simplified node visitor for traversing annotations.""" + @abstractmethod + def is_typing(self, node: ast.AST, symbol: str) -> bool: + """Check if the given node matches the given typing symbol.""" + @abstractmethod def visit_annotation_name(self, node: ast.Name) -> None: """Visit a name inside an annotation.""" @@ -80,6 +85,10 @@ def visit_annotation_name(self, node: ast.Name) -> None: def visit_annotation_string(self, node: ast.Constant) -> None: """Visit a string constant inside an annotation.""" + def visit_annotated_value(self, node: ast.expr) -> None: + """Visit a value expression of `typing.Annotated[type, value]`.""" + return + def visit(self, node: ast.AST) -> None: """Visit relevant child nodes on an annotation.""" if node is None: @@ -95,15 +104,19 @@ def visit(self, node: ast.AST) -> None: self.visit(node.value) elif isinstance(node, ast.Subscript): self.visit(node.value) - if getattr(node.value, 'id', '') == 'Annotated' and isinstance( + if self.is_typing(node.value, 'Literal'): + return + elif self.is_typing(node.value, 'Annotated') and isinstance( (elts_node := node.slice.value if py38 and isinstance(node.slice, Index) else node.slice), (ast.Tuple, ast.List), ): if elts_node.elts: - # only visit the first element - self.visit(elts_node.elts[0]) - # TODO: We may want to visit the rest as a soft-runtime use - elif getattr(node.value, 'id', '') != 'Literal': + elts_iter = iter(elts_node.elts) + # only visit the first element like a type expression + self.visit(next(elts_iter)) + for value_node in elts_iter: + self.visit_annotated_value(value_node) + else: self.visit(node.slice) elif isinstance(node, (ast.Tuple, ast.List)): for n in node.elts: @@ -300,18 +313,29 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None: class SQLAlchemyAnnotationVisitor(AnnotationVisitor): """Adds any names in the annotation to mapped names.""" - def __init__(self, mapped_names: set[str]) -> None: - self.mapped_names = mapped_names + def __init__(self, sqlalchemy_plugin: SQLAlchemyMixin) -> None: + self.plugin = sqlalchemy_plugin + + def is_typing(self, node: ast.AST, symbol: str) -> bool: + """Check if the given node matches the given typing symbol.""" + return self.plugin.is_typing(node, symbol) + + def visit_annotated_value(self, node: ast.expr) -> None: + """Visit a value expression of `typing.Annotated[type, value]`.""" + previous_context = self.plugin.in_soft_use_context + self.plugin.in_soft_use_context = True + self.plugin.visit(node) + self.plugin.in_soft_use_context = previous_context def visit_annotation_name(self, node: ast.Name) -> None: """Add name to mapped names.""" - self.mapped_names.add(node.id) + self.plugin.soft_uses.add(node.id) def visit_annotation_string(self, node: ast.Constant) -> None: """Add all the names in the string to mapped names.""" - visitor = StringAnnotationVisitor() + visitor = StringAnnotationVisitor(self.plugin) visitor.parse_and_visit_string_annotation(node.value) - self.mapped_names.update(visitor.names) + self.plugin.soft_uses.update(visitor.names) class SQLAlchemyMixin: @@ -333,6 +357,8 @@ class SQLAlchemyMixin: sqlalchemy_mapped_dotted_names: set[str] current_scope: Scope uses: dict[str, list[tuple[ast.AST, Scope]]] + soft_uses: set[str] + in_soft_use_context: bool def in_type_checking_block(self, lineno: int, col_offset: int) -> bool: # noqa: D102 ... @@ -340,16 +366,17 @@ def in_type_checking_block(self, lineno: int, col_offset: int) -> bool: # noqa: def lookup_full_name(self, node: ast.AST) -> str | None: # noqa: D102 ... + def is_typing(self, node: ast.AST, symbol: str) -> bool: # noqa: D102 + ... + + def visit(self, node: ast.AST) -> ast.AST: # noqa: D102 + ... + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - #: Contains a set of all names used inside `Mapped[...]` annotations - # These are treated like soft-uses, i.e. we don't know if it will be - # used at runtime or not - self.mapped_names: set[str] = set() - #: Used for visiting annotations - self.sqlalchemy_annotation_visitor = SQLAlchemyAnnotationVisitor(self.mapped_names) + self.sqlalchemy_annotation_visitor = SQLAlchemyAnnotationVisitor(self) def visit_AnnAssign(self, node: ast.AnnAssign) -> None: """Remove all annotations assigments.""" @@ -403,9 +430,9 @@ def handle_sqlalchemy_annotation(self, node: ast.AST) -> None: # add all names contained in the inner part of the annotation # since this is not as strict as an actual runtime use, we don't # care if we record too much here - visitor = StringAnnotationVisitor() + visitor = StringAnnotationVisitor(self) visitor.parse_and_visit_string_annotation(inner) - self.mapped_names.update(visitor.names) + self.soft_uses.update(visitor.names) return # we only need to handle annotations like `Mapped[...]` @@ -780,9 +807,14 @@ def lookup(self, symbol_name: str, use: HasPosition | None = None, runtime_only: class StringAnnotationVisitor(AnnotationVisitor): """Visit a parsed string annotation and collect all the names.""" - def __init__(self) -> None: + def __init__(self, typing_lookup: SupportsIsTyping) -> None: #: All the names referenced inside the annotation self.names: set[str] = set() + self._typing_lookup = typing_lookup + + def is_typing(self, node: ast.AST, symbol: str) -> bool: + """Check if the given node matches the given typing symbol.""" + return self._typing_lookup.is_typing(node, symbol) def parse_and_visit_string_annotation(self, annotation: str) -> None: """Parse and visit the given string as an annotation expression.""" @@ -814,7 +846,7 @@ def visit_annotation_string(self, node: ast.Constant) -> None: class ImportAnnotationVisitor(AnnotationVisitor): """Map all annotations on an AST node.""" - def __init__(self) -> None: + def __init__(self, import_visitor: ImportVisitor) -> None: #: All type annotations in the file, without quotes around them self.unwrapped_annotations: list[UnwrappedAnnotation] = [] @@ -831,6 +863,12 @@ def __init__(self) -> None: # e.g. AnnAssign.annotation within a function body will never evaluate self.never_evaluates = False + self.import_visitor = import_visitor + + def is_typing(self, node: ast.AST, symbol: str) -> bool: + """Check if the given node matches the given typing symbol.""" + return self.import_visitor.is_typing(node, symbol) + def visit( self, node: ast.AST, @@ -864,12 +902,19 @@ def visit_annotation_string(self, node: ast.Constant) -> None: if getattr(node, BINOP_OPERAND_PROPERTY, False): self.invalid_binop_literals.append(node) else: - visitor = StringAnnotationVisitor() + visitor = StringAnnotationVisitor(self.import_visitor) visitor.parse_and_visit_string_annotation(node.value) (self.excess_wrapped_annotations if self.never_evaluates else self.wrapped_annotations).append( WrappedAnnotation(node.lineno, node.col_offset, node.value, visitor.names, self.scope, self.type) ) + def visit_annotated_value(self, node: ast.expr) -> None: + """Visit a value expression of `typing.Annotated[type, value]`.""" + previous_context = self.import_visitor.in_soft_use_context + self.import_visitor.in_soft_use_context = True + self.import_visitor.visit(node) + self.import_visitor.in_soft_use_context = previous_context + class ImportVisitor( DunderAllMixin, AttrsMixin, InjectorMixin, FastAPIMixin, PydanticMixin, SQLAlchemyMixin, ast.NodeVisitor @@ -932,8 +977,12 @@ def __init__( #: List of all names and ids, except type declarations self.uses: dict[str, list[tuple[ast.AST, Scope]]] = defaultdict(list) + #: Contains a set of all names to be treated like soft-uses. + # i.e. we don't know if it will be used at runtime or not + self.soft_uses: set[str] = set() + #: Handles logic of visiting annotation nodes - self.annotation_visitor = ImportAnnotationVisitor() + self.annotation_visitor = ImportAnnotationVisitor(self) #: Whether there is a `from __futures__ import annotations` is present in the file self.futures_annotation: Optional[bool] = None @@ -951,6 +1000,10 @@ def __init__( #: For tracking which comprehension/IfExp we're currently inside of self.active_context: Optional[Comprehension | ast.IfExp] = None + #: Whether or not we're in a context where uses count as soft-uses. + # E.g. the value expression of `typing.Annotated[type, value]` + self.in_soft_use_context: bool = False + @contextmanager def create_scope(self, node: ast.ClassDef | Function, is_head: bool = True) -> Iterator[Scope]: """Create a new scope.""" @@ -1853,7 +1906,7 @@ def unused_imports(self) -> Flake8Generator: } all_imports = {name for name, imp in self.visitor.imports.items() if not imp.exempt} - unused_imports = all_imports - self.visitor.names - self.visitor.mapped_names + unused_imports = all_imports - self.visitor.names - self.visitor.soft_uses used_imports = all_imports - unused_imports already_imported_modules = [self.visitor.imports[name].module for name in used_imports] annotation_names = ( diff --git a/flake8_type_checking/types.py b/flake8_type_checking/types.py index 8e2af34..85db8af 100644 --- a/flake8_type_checking/types.py +++ b/flake8_type_checking/types.py @@ -24,5 +24,9 @@ def lineno(self) -> int: def col_offset(self) -> int: pass + class SupportsIsTyping(Protocol): + def is_typing(self, node: ast.AST, symbol: str) -> bool: + pass + ImportTypeValue = Literal['APPLICATION', 'THIRD_PARTY', 'BUILTIN', 'FUTURE'] diff --git a/tests/test_name_extraction.py b/tests/test_name_extraction.py index 1739756..88689d6 100644 --- a/tests/test_name_extraction.py +++ b/tests/test_name_extraction.py @@ -1,38 +1,52 @@ +import ast import sys import pytest -from flake8_type_checking.checker import StringAnnotationVisitor +from flake8_type_checking.checker import ImportVisitor, StringAnnotationVisitor examples = [ - ('', set()), - ('invalid_syntax]', set()), - ('int', {'int'}), - ('dict[str, int]', {'dict', 'str', 'int'}), + ('', set(), set()), + ('invalid_syntax]', set(), set()), + ('int', {'int'}, set()), + ('dict[str, int]', {'dict', 'str', 'int'}, set()), # make sure literals don't add names for their contents - ('Literal["a"]', {'Literal'}), - ("Literal['a']", {'Literal'}), - ('Literal[0]', {'Literal'}), - ('Literal[1.0]', {'Literal'}), - ('Literal[True]', {'Literal'}), - ('T | S', {'T', 'S'}), - ('Union[Dict[str, Any], Literal["Foo", "Bar"], _T]', {'Union', 'Dict', 'str', 'Any', 'Literal', '_T'}), + ('Literal["a"]', {'Literal'}, set()), + ("Literal['a']", {'Literal'}, set()), + ('Literal[0]', {'Literal'}, set()), + ('Literal[1.0]', {'Literal'}, set()), + ('Literal[True]', {'Literal'}, set()), + ('L[a]', {'L'}, set()), + ('T | S', {'T', 'S'}, set()), + ('Union[Dict[str, Any], Literal["Foo", "Bar"], _T]', {'Union', 'Dict', 'str', 'Any', 'Literal', '_T'}, set()), # for attribute access only everything up to the first dot should count # this matches the behavior of add_annotation - ('datetime.date | os.path.sep', {'datetime', 'os'}), - ('Nested["str"]', {'Nested', 'str'}), - ('Annotated[str, validator]', {'Annotated', 'str'}), - ('Annotated[str, "bool"]', {'Annotated', 'str'}), + ('datetime.date | os.path.sep', {'datetime', 'os'}, set()), + ('Nested["str"]', {'Nested', 'str'}, set()), + ('Annotated[str, validator(int, 5)]', {'Annotated', 'str'}, {'validator', 'int'}), + ('Annotated[str, "bool"]', {'Annotated', 'str'}, set()), ] if sys.version_info >= (3, 11): examples.extend([ - ('*Ts', {'Ts'}), + ('*Ts', {'Ts'}, set()), ]) -@pytest.mark.parametrize(('example', 'expected'), examples) -def test_name_extraction(example, expected): - visitor = StringAnnotationVisitor() +@pytest.mark.parametrize(('example', 'expected', 'soft_uses'), examples) +def test_name_extraction(example, expected, soft_uses): + import_visitor = ImportVisitor( + cwd='fake cwd', # type: ignore[arg-type] + pydantic_enabled=False, + fastapi_enabled=False, + fastapi_dependency_support_enabled=False, + cattrs_enabled=False, + sqlalchemy_enabled=False, + sqlalchemy_mapped_dotted_names=[], + injector_enabled=False, + pydantic_enabled_baseclass_passlist=[], + ) + import_visitor.visit(ast.parse('from typing import Annotated, Literal, Literal as L')) + visitor = StringAnnotationVisitor(import_visitor) visitor.parse_and_visit_string_annotation(example) assert visitor.names == expected diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py index b8e1079..163fbd7 100644 --- a/tests/test_sqlalchemy.py +++ b/tests/test_sqlalchemy.py @@ -75,6 +75,25 @@ class User: assert _get_error(example, error_code_filter='TC002', type_checking_sqlalchemy_enabled=True) == set() +def test_mapped_soft_uses(): + """ + Everything inside Mapped is a soft-use, including `Annotated` value expresions + as such we can't trigger a TC002 here, despite the only uses being inside + type annotations. + """ + example = textwrap.dedent(''' + from foo import Bar, Gt + from sqlalchemy.orm import Mapped + from typing import Annotated + + class User: + number: Mapped[Annotated[Bar, Gt(2)]] + bar: Bar + validator: Gt + ''') + assert _get_error(example, error_code_filter='TC002', type_checking_sqlalchemy_enabled=True) == set() + + def test_mapped_use_without_runtime_import(): """ Mapped must be available at runtime, so even if it is inside a wrapped annotation diff --git a/tests/test_tc001_to_tc003.py b/tests/test_tc001_to_tc003.py index 4582f64..103681f 100644 --- a/tests/test_tc001_to_tc003.py +++ b/tests/test_tc001_to_tc003.py @@ -203,6 +203,7 @@ class Migration: '''), set(), ), + # Annotated soft use ( textwrap.dedent(f''' from typing import Annotated @@ -210,9 +211,11 @@ class Migration: from {import_} import Depends x: Annotated[str, Depends] + y: Depends '''), set(), ), + # This is not a soft-use, it's just a plain string ( textwrap.dedent(f''' from typing import Annotated @@ -220,8 +223,9 @@ class Migration: from {import_} import Depends x: Annotated[str, "Depends"] + y: Depends '''), - set(), + {'4:0 ' + ERROR.format(module=f'{import_}.Depends')}, ), ] From 9082f43d5b52f5e665a9076b3c38e81895486bec Mon Sep 17 00:00:00 2001 From: David Salvisberg Date: Thu, 12 Dec 2024 10:16:33 +0100 Subject: [PATCH 2/3] Fixes and tests soft uses for `typing.Annotated` --- flake8_type_checking/checker.py | 52 ++++++++++++-- tests/test_name_extraction.py | 38 +++++------ tests/test_name_visitor.py | 117 ++++++++++++++++++++++---------- 3 files changed, 144 insertions(+), 63 deletions(-) diff --git a/flake8_type_checking/checker.py b/flake8_type_checking/checker.py index 7eb1d15..91e166b 100644 --- a/flake8_type_checking/checker.py +++ b/flake8_type_checking/checker.py @@ -85,6 +85,10 @@ def visit_annotation_name(self, node: ast.Name) -> None: def visit_annotation_string(self, node: ast.Constant) -> None: """Visit a string constant inside an annotation.""" + def visit_annotated_type(self, node: ast.expr) -> None: + """Visit a type expression of `typing.Annotated[type, value]`.""" + self.visit(node) + def visit_annotated_value(self, node: ast.expr) -> None: """Visit a value expression of `typing.Annotated[type, value]`.""" return @@ -113,7 +117,7 @@ def visit(self, node: ast.AST) -> None: if elts_node.elts: elts_iter = iter(elts_node.elts) # only visit the first element like a type expression - self.visit(next(elts_iter)) + self.visit_annotated_type(next(elts_iter)) for value_node in elts_iter: self.visit_annotated_value(value_node) else: @@ -863,6 +867,10 @@ def __init__(self, import_visitor: ImportVisitor) -> None: # e.g. AnnAssign.annotation within a function body will never evaluate self.never_evaluates = False + #: Whether or not we're currently in a soft-use context + # e.g. the type expression of `Annotated[type, value] + self.in_soft_use_context = False + self.import_visitor = import_visitor def is_typing(self, node: ast.AST, symbol: str) -> bool: @@ -891,6 +899,9 @@ def visit_annotation_name(self, node: ast.Name) -> None: if self.never_evaluates: return + if self.in_soft_use_context: + self.import_visitor.soft_uses.add(node.id) + self.unwrapped_annotations.append( UnwrappedAnnotation(node.lineno, node.col_offset, node.id, self.scope, self.type) ) @@ -904,12 +915,33 @@ def visit_annotation_string(self, node: ast.Constant) -> None: else: visitor = StringAnnotationVisitor(self.import_visitor) visitor.parse_and_visit_string_annotation(node.value) + if self.in_soft_use_context: + self.import_visitor.soft_uses.update(visitor.names) (self.excess_wrapped_annotations if self.never_evaluates else self.wrapped_annotations).append( WrappedAnnotation(node.lineno, node.col_offset, node.value, visitor.names, self.scope, self.type) ) + def visit_annotated_type(self, node: ast.expr) -> None: + """Visit a type expression of `typing.Annotated[type, value]`.""" + if self.never_evaluates: + return + + previous_context = self.in_soft_use_context + self.in_soft_use_context = True + self.visit(node) + self.in_soft_use_context = previous_context + def visit_annotated_value(self, node: ast.expr) -> None: """Visit a value expression of `typing.Annotated[type, value]`.""" + if self.never_evaluates: + return + + if self.type == 'alias' or (self.type == 'annotation' and not self.import_visitor.futures_annotation): + # visit nodes in regular runtime context + self.import_visitor.visit(node) + return + + # visit nodes in soft use context previous_context = self.import_visitor.in_soft_use_context self.import_visitor.in_soft_use_context = True self.import_visitor.visit(node) @@ -978,13 +1010,14 @@ def __init__( self.uses: dict[str, list[tuple[ast.AST, Scope]]] = defaultdict(list) #: Contains a set of all names to be treated like soft-uses. - # i.e. we don't know if it will be used at runtime or not + # i.e. we don't know if it will be used at runtime or not, so + # we should assume the imports are currently correct self.soft_uses: set[str] = set() #: Handles logic of visiting annotation nodes self.annotation_visitor = ImportAnnotationVisitor(self) - #: Whether there is a `from __futures__ import annotations` is present in the file + #: Whether there is a `from __futures__ import annotations` present in the file self.futures_annotation: Optional[bool] = None #: Where the type checking block exists (line_start, line_end, col_offset) @@ -1000,8 +1033,7 @@ def __init__( #: For tracking which comprehension/IfExp we're currently inside of self.active_context: Optional[Comprehension | ast.IfExp] = None - #: Whether or not we're in a context where uses count as soft-uses. - # E.g. the value expression of `typing.Annotated[type, value]` + #: Whether we're in a value expression of `typing.Annotated[type, value]`. self.in_soft_use_context: bool = False @contextmanager @@ -1387,10 +1419,16 @@ def visit_Name(self, node: ast.Name) -> ast.Name: if self.in_type_checking_block(node.lineno, node.col_offset): return node + names = [node.id] if hasattr(node, ATTRIBUTE_PROPERTY): - self.uses[f'{node.id}.{getattr(node, ATTRIBUTE_PROPERTY)}'].append((node, self.current_scope)) + names.append(f'{node.id}.{getattr(node, ATTRIBUTE_PROPERTY)}') + + if self.in_soft_use_context: + self.soft_uses.update(names) + else: + for name in names: + self.uses[name].append((node, self.current_scope)) - self.uses[node.id].append((node, self.current_scope)) return node def visit_Constant(self, node: ast.Constant) -> ast.Constant: diff --git a/tests/test_name_extraction.py b/tests/test_name_extraction.py index 88689d6..e00aac2 100644 --- a/tests/test_name_extraction.py +++ b/tests/test_name_extraction.py @@ -6,35 +6,35 @@ from flake8_type_checking.checker import ImportVisitor, StringAnnotationVisitor examples = [ - ('', set(), set()), - ('invalid_syntax]', set(), set()), - ('int', {'int'}, set()), - ('dict[str, int]', {'dict', 'str', 'int'}, set()), + ('', set()), + ('invalid_syntax]', set()), + ('int', {'int'}), + ('dict[str, int]', {'dict', 'str', 'int'}), # make sure literals don't add names for their contents - ('Literal["a"]', {'Literal'}, set()), - ("Literal['a']", {'Literal'}, set()), - ('Literal[0]', {'Literal'}, set()), - ('Literal[1.0]', {'Literal'}, set()), - ('Literal[True]', {'Literal'}, set()), - ('L[a]', {'L'}, set()), - ('T | S', {'T', 'S'}, set()), - ('Union[Dict[str, Any], Literal["Foo", "Bar"], _T]', {'Union', 'Dict', 'str', 'Any', 'Literal', '_T'}, set()), + ('Literal["a"]', {'Literal'}), + ("Literal['a']", {'Literal'}), + ('Literal[0]', {'Literal'}), + ('Literal[1.0]', {'Literal'}), + ('Literal[True]', {'Literal'}), + ('L[a]', {'L'}), + ('T | S', {'T', 'S'}), + ('Union[Dict[str, Any], Literal["Foo", "Bar"], _T]', {'Union', 'Dict', 'str', 'Any', 'Literal', '_T'}), # for attribute access only everything up to the first dot should count # this matches the behavior of add_annotation - ('datetime.date | os.path.sep', {'datetime', 'os'}, set()), - ('Nested["str"]', {'Nested', 'str'}, set()), - ('Annotated[str, validator(int, 5)]', {'Annotated', 'str'}, {'validator', 'int'}), - ('Annotated[str, "bool"]', {'Annotated', 'str'}, set()), + ('datetime.date | os.path.sep', {'datetime', 'os'}), + ('Nested["str"]', {'Nested', 'str'}), + ('Annotated[str, validator(int, 5)]', {'Annotated', 'str'}), + ('Annotated[str, "bool"]', {'Annotated', 'str'}), ] if sys.version_info >= (3, 11): examples.extend([ - ('*Ts', {'Ts'}, set()), + ('*Ts', {'Ts'}), ]) -@pytest.mark.parametrize(('example', 'expected', 'soft_uses'), examples) -def test_name_extraction(example, expected, soft_uses): +@pytest.mark.parametrize(('example', 'expected'), examples) +def test_name_extraction(example, expected): import_visitor = ImportVisitor( cwd='fake cwd', # type: ignore[arg-type] pydantic_enabled=False, diff --git a/tests/test_name_visitor.py b/tests/test_name_visitor.py index c282108..4e85ae1 100644 --- a/tests/test_name_visitor.py +++ b/tests/test_name_visitor.py @@ -1,18 +1,15 @@ from __future__ import annotations import ast +import sys import textwrap -from typing import TYPE_CHECKING import pytest from flake8_type_checking.checker import ImportVisitor -if TYPE_CHECKING: - from typing import Set - -def _get_names(example: str) -> Set[str]: +def _get_names_and_soft_uses(example: str) -> tuple[set[str], set[str]]: visitor = ImportVisitor( cwd='fake cwd', # type: ignore[arg-type] pydantic_enabled=False, @@ -25,56 +22,59 @@ def _get_names(example: str) -> Set[str]: pydantic_enabled_baseclass_passlist=[], ) visitor.visit(ast.parse(example)) - return visitor.names + return visitor.names, visitor.soft_uses examples = [ # ast.Import - ('import x', set()), - ('import pytest', set()), - ('import flake8_type_checking', set()), + ('import x', set(), set()), + ('import pytest', set(), set()), + ('import flake8_type_checking', set(), set()), # ast.ImportFrom - ('from x import y', set()), - ('from _pytest import fixtures', set()), - ('from flake8_type_checking import constants', set()), + ('from x import y', set(), set()), + ('from _pytest import fixtures', set(), set()), + ('from flake8_type_checking import constants', set(), set()), # Assignments - ('x = y', {'x', 'y'}), - ('x, y = z', {'x', 'y', 'z'}), - ('x, y, z = a, b, c()', {'x', 'y', 'z', 'a', 'b', 'c'}), + ('x = y', {'x', 'y'}, set()), + ('x, y = z', {'x', 'y', 'z'}, set()), + ('x, y, z = a, b, c()', {'x', 'y', 'z', 'a', 'b', 'c'}, set()), # Calls - ('x()', {'x'}), - ('x = y()', {'x', 'y'}), - ('def example(): x = y(); z()', {'x', 'y', 'z'}), + ('x()', {'x'}, set()), + ('x = y()', {'x', 'y'}, set()), + ('def example(): x = y(); z()', {'x', 'y', 'z'}, set()), # Attribute - ('x.y', {'x.y', 'x'}), + ('x.y', {'x.y', 'x'}, set()), ( textwrap.dedent(""" - def example(c): - a = 2 - b = c * 2 - """), + def example(c): + a = 2 + b = c * 2 + """), {'a', 'b', 'c'}, + set(), ), ( textwrap.dedent(""" - class Test: - x = 13 + class Test: + x = 13 - def __init__(self, z): - self.y = z + def __init__(self, z): + self.y = z - a = Test() - b = a.y - """), + a = Test() + b = a.y + """), {'self.y', 'z', 'Test', 'self', 'a', 'b', 'x', 'a.y'}, + set(), ), ( textwrap.dedent(""" - import ast + import ast - ImportType = Union[Import, ImportFrom] - """), # ast should not be a part of this + ImportType = Union[Import, ImportFrom] + """), # ast should not be a part of this {'Union', 'Import', 'ImportFrom', 'ImportType'}, + set(), ), ( textwrap.dedent(""" @@ -85,13 +85,53 @@ def _get_usages(example): return visitor.usage_names """), {'UnusedImportVisitor', 'example', 'parse', 'visitor', 'visitor.usage_names', 'visitor.visit'}, + set(), + ), + ( + textwrap.dedent(""" + from typing import Annotated + + from foo import Gt + + x: Annotated[int, Gt(5)] + """), + {'Gt'}, + {'int'}, + ), + ( + textwrap.dedent(""" + from __future__ import annotations + + from typing import Annotated + + from foo import Gt + + x: Annotated[int, Gt(5)] + """), + set(), + {'Gt', 'int'}, ), ] +if sys.version_info >= (3, 12): + examples.extend([ + ( + textwrap.dedent(""" + from typing import Annotated + + from foo import Gt + + type x = Annotated[int, Gt(5)] + """), + set(), + {'Gt', 'int'}, + ), + ]) -@pytest.mark.parametrize(('example', 'result'), examples) -def test_basic_annotations_are_removed(example, result): - assert _get_names(example) == result + +@pytest.mark.parametrize(('example', 'result', 'soft_uses'), examples) +def test_basic_annotations_are_removed(example, result, soft_uses): + assert _get_names_and_soft_uses(example) == (result, soft_uses) def test_model_declarations_are_included_in_names(): @@ -106,4 +146,7 @@ class LoanProvider(models.Model): on_delete=models.CASCADE, ) """) - assert _get_names(example) == {'SomeModel', 'fk', 'models', 'models.CASCADE', 'models.ForeignKey', 'models.Model'} + assert _get_names_and_soft_uses(example) == ( + {'SomeModel', 'fk', 'models', 'models.CASCADE', 'models.ForeignKey', 'models.Model'}, + set(), + ) From e587a64f3da461db5e0c2ae4fd32bbcede07c61e Mon Sep 17 00:00:00 2001 From: David Salvisberg Date: Thu, 12 Dec 2024 10:21:56 +0100 Subject: [PATCH 3/3] Reverts accidental docstring change --- flake8_type_checking/checker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flake8_type_checking/checker.py b/flake8_type_checking/checker.py index 91e166b..f3ff3f1 100644 --- a/flake8_type_checking/checker.py +++ b/flake8_type_checking/checker.py @@ -1033,7 +1033,8 @@ def __init__( #: For tracking which comprehension/IfExp we're currently inside of self.active_context: Optional[Comprehension | ast.IfExp] = None - #: Whether we're in a value expression of `typing.Annotated[type, value]`. + #: Whether or not we're in a context where uses count as soft-uses. + # E.g. the type expression of `typing.Annotated[type, value]` self.in_soft_use_context: bool = False @contextmanager