Skip to content

Commit

Permalink
refactor: Make use of is_typing for Literal/Annotated detection.
Browse files Browse the repository at this point in the history
Makes soft use concept of `mapped_names` more generic and reuse it for
walking `Annotated` value expressions.
  • Loading branch information
Daverball committed Dec 9, 2024
1 parent 68d42a8 commit f1c791f
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 43 deletions.
88 changes: 65 additions & 23 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,17 @@ def ast_unparse(node: ast.AST) -> str:
Import,
ImportTypeValue,
Name,
SupportsIsTyping,
)


class AnnotationVisitor(ABC):
"""Simplified node visitor for traversing annotations."""

@abstractmethod
def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""

@abstractmethod
def visit_annotation_name(self, node: ast.Name) -> None:
"""Visit a name inside an annotation."""
Expand All @@ -80,6 +85,10 @@ def visit_annotation_name(self, node: ast.Name) -> None:
def visit_annotation_string(self, node: ast.Constant) -> None:
"""Visit a string constant inside an annotation."""

def visit_annotated_value(self, node: ast.expr) -> None:
"""Visit a value expression of `typing.Annotated[type, value]`."""
return

def visit(self, node: ast.AST) -> None:
"""Visit relevant child nodes on an annotation."""
if node is None:
Expand All @@ -95,15 +104,19 @@ def visit(self, node: ast.AST) -> None:
self.visit(node.value)
elif isinstance(node, ast.Subscript):
self.visit(node.value)
if getattr(node.value, 'id', '') == 'Annotated' and isinstance(
if self.is_typing(node.value, 'Literal'):
return
elif self.is_typing(node.value, 'Annotated') and isinstance(
(elts_node := node.slice.value if py38 and isinstance(node.slice, Index) else node.slice),
(ast.Tuple, ast.List),
):
if elts_node.elts:
# only visit the first element
self.visit(elts_node.elts[0])
# TODO: We may want to visit the rest as a soft-runtime use
elif getattr(node.value, 'id', '') != 'Literal':
elts_iter = iter(elts_node.elts)
# only visit the first element like a type expression
self.visit(next(elts_iter))
for value_node in elts_iter:
self.visit_annotated_value(value_node)
else:
self.visit(node.slice)
elif isinstance(node, (ast.Tuple, ast.List)):
for n in node.elts:
Expand Down Expand Up @@ -300,18 +313,22 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
class SQLAlchemyAnnotationVisitor(AnnotationVisitor):
"""Adds any names in the annotation to mapped names."""

def __init__(self, mapped_names: set[str]) -> None:
self.mapped_names = mapped_names
def __init__(self, sqlalchemy_plugin: SQLAlchemyMixin) -> None:
self.plugin = sqlalchemy_plugin

def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""
return self.plugin.is_typing(node, symbol)

def visit_annotation_name(self, node: ast.Name) -> None:
"""Add name to mapped names."""
self.mapped_names.add(node.id)
self.plugin.soft_uses.add(node.id)

def visit_annotation_string(self, node: ast.Constant) -> None:
"""Add all the names in the string to mapped names."""
visitor = StringAnnotationVisitor()
visitor = StringAnnotationVisitor(self.plugin)
visitor.parse_and_visit_string_annotation(node.value)
self.mapped_names.update(visitor.names)
self.plugin.soft_uses.update(visitor.names)


class SQLAlchemyMixin:
Expand All @@ -333,23 +350,22 @@ class SQLAlchemyMixin:
sqlalchemy_mapped_dotted_names: set[str]
current_scope: Scope
uses: dict[str, list[tuple[ast.AST, Scope]]]
soft_uses: set[str]

def in_type_checking_block(self, lineno: int, col_offset: int) -> bool: # noqa: D102
...

def lookup_full_name(self, node: ast.AST) -> str | None: # noqa: D102
...

def is_typing(self, node: ast.AST, symbol: str) -> bool: # noqa: D102
...

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

#: Contains a set of all names used inside `Mapped[...]` annotations
# These are treated like soft-uses, i.e. we don't know if it will be
# used at runtime or not
self.mapped_names: set[str] = set()

#: Used for visiting annotations
self.sqlalchemy_annotation_visitor = SQLAlchemyAnnotationVisitor(self.mapped_names)
self.sqlalchemy_annotation_visitor = SQLAlchemyAnnotationVisitor(self)

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
"""Remove all annotations assigments."""
Expand Down Expand Up @@ -403,9 +419,9 @@ def handle_sqlalchemy_annotation(self, node: ast.AST) -> None:
# add all names contained in the inner part of the annotation
# since this is not as strict as an actual runtime use, we don't
# care if we record too much here
visitor = StringAnnotationVisitor()
visitor = StringAnnotationVisitor(self)
visitor.parse_and_visit_string_annotation(inner)
self.mapped_names.update(visitor.names)
self.soft_uses.update(visitor.names)
return

# we only need to handle annotations like `Mapped[...]`
Expand Down Expand Up @@ -780,9 +796,14 @@ def lookup(self, symbol_name: str, use: HasPosition | None = None, runtime_only:
class StringAnnotationVisitor(AnnotationVisitor):
"""Visit a parsed string annotation and collect all the names."""

def __init__(self) -> None:
def __init__(self, typing_lookup: SupportsIsTyping) -> None:
#: All the names referenced inside the annotation
self.names: set[str] = set()
self._typing_lookup = typing_lookup

def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""
return self._typing_lookup.is_typing(node, symbol)

def parse_and_visit_string_annotation(self, annotation: str) -> None:
"""Parse and visit the given string as an annotation expression."""
Expand Down Expand Up @@ -814,7 +835,7 @@ def visit_annotation_string(self, node: ast.Constant) -> None:
class ImportAnnotationVisitor(AnnotationVisitor):
"""Map all annotations on an AST node."""

def __init__(self) -> None:
def __init__(self, import_visitor: ImportVisitor) -> None:
#: All type annotations in the file, without quotes around them
self.unwrapped_annotations: list[UnwrappedAnnotation] = []

Expand All @@ -831,6 +852,12 @@ def __init__(self) -> None:
# e.g. AnnAssign.annotation within a function body will never evaluate
self.never_evaluates = False

self.import_visitor = import_visitor

def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""
return self.import_visitor.is_typing(node, symbol)

def visit(
self,
node: ast.AST,
Expand Down Expand Up @@ -864,12 +891,19 @@ def visit_annotation_string(self, node: ast.Constant) -> None:
if getattr(node, BINOP_OPERAND_PROPERTY, False):
self.invalid_binop_literals.append(node)
else:
visitor = StringAnnotationVisitor()
visitor = StringAnnotationVisitor(self.import_visitor)
visitor.parse_and_visit_string_annotation(node.value)
(self.excess_wrapped_annotations if self.never_evaluates else self.wrapped_annotations).append(
WrappedAnnotation(node.lineno, node.col_offset, node.value, visitor.names, self.scope, self.type)
)

def visit_annotated_value(self, node: ast.expr) -> None:
"""Visit a value expression of `typing.Annotated[type, value]`."""
previous_context = self.import_visitor.in_soft_use_context
self.import_visitor.in_soft_use_context = True
self.import_visitor.visit(node)
self.import_visitor.in_soft_use_context = previous_context


class ImportVisitor(
DunderAllMixin, AttrsMixin, InjectorMixin, FastAPIMixin, PydanticMixin, SQLAlchemyMixin, ast.NodeVisitor
Expand Down Expand Up @@ -932,8 +966,12 @@ def __init__(
#: List of all names and ids, except type declarations
self.uses: dict[str, list[tuple[ast.AST, Scope]]] = defaultdict(list)

#: Contains a set of all names to be treated like soft-uses.
# i.e. we don't know if it will be used at runtime or not
self.soft_uses: set[str] = set()

#: Handles logic of visiting annotation nodes
self.annotation_visitor = ImportAnnotationVisitor()
self.annotation_visitor = ImportAnnotationVisitor(self)

#: Whether there is a `from __futures__ import annotations` is present in the file
self.futures_annotation: Optional[bool] = None
Expand All @@ -951,6 +989,10 @@ def __init__(
#: For tracking which comprehension/IfExp we're currently inside of
self.active_context: Optional[Comprehension | ast.IfExp] = None

#: Whether or not we're in a context where uses count as soft-uses.
# E.g. the value expression of `typing.Annotated[type, value]`
self.in_soft_use_context: bool = False

@contextmanager
def create_scope(self, node: ast.ClassDef | Function, is_head: bool = True) -> Iterator[Scope]:
"""Create a new scope."""
Expand Down Expand Up @@ -1853,7 +1895,7 @@ def unused_imports(self) -> Flake8Generator:
}

all_imports = {name for name, imp in self.visitor.imports.items() if not imp.exempt}
unused_imports = all_imports - self.visitor.names - self.visitor.mapped_names
unused_imports = all_imports - self.visitor.names - self.visitor.soft_uses
used_imports = all_imports - unused_imports
already_imported_modules = [self.visitor.imports[name].module for name in used_imports]
annotation_names = (
Expand Down
4 changes: 4 additions & 0 deletions flake8_type_checking/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,9 @@ def lineno(self) -> int:
def col_offset(self) -> int:
pass

class SupportsIsTyping(Protocol):
def is_typing(self, node: ast.AST, symbol: str) -> bool:
pass


ImportTypeValue = Literal['APPLICATION', 'THIRD_PARTY', 'BUILTIN', 'FUTURE']
52 changes: 33 additions & 19 deletions tests/test_name_extraction.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
import ast
import sys

import pytest

from flake8_type_checking.checker import StringAnnotationVisitor
from flake8_type_checking.checker import ImportVisitor, StringAnnotationVisitor

examples = [
('', set()),
('invalid_syntax]', set()),
('int', {'int'}),
('dict[str, int]', {'dict', 'str', 'int'}),
('', set(), set()),
('invalid_syntax]', set(), set()),
('int', {'int'}, set()),
('dict[str, int]', {'dict', 'str', 'int'}, set()),
# make sure literals don't add names for their contents
('Literal["a"]', {'Literal'}),
("Literal['a']", {'Literal'}),
('Literal[0]', {'Literal'}),
('Literal[1.0]', {'Literal'}),
('Literal[True]', {'Literal'}),
('T | S', {'T', 'S'}),
('Union[Dict[str, Any], Literal["Foo", "Bar"], _T]', {'Union', 'Dict', 'str', 'Any', 'Literal', '_T'}),
('Literal["a"]', {'Literal'}, set()),
("Literal['a']", {'Literal'}, set()),
('Literal[0]', {'Literal'}, set()),
('Literal[1.0]', {'Literal'}, set()),
('Literal[True]', {'Literal'}, set()),
('L[a]', {'L'}, set()),
('T | S', {'T', 'S'}, set()),
('Union[Dict[str, Any], Literal["Foo", "Bar"], _T]', {'Union', 'Dict', 'str', 'Any', 'Literal', '_T'}, set()),
# for attribute access only everything up to the first dot should count
# this matches the behavior of add_annotation
('datetime.date | os.path.sep', {'datetime', 'os'}),
('Nested["str"]', {'Nested', 'str'}),
('Annotated[str, validator]', {'Annotated', 'str'}),
('Annotated[str, "bool"]', {'Annotated', 'str'}),
('datetime.date | os.path.sep', {'datetime', 'os'}, set()),
('Nested["str"]', {'Nested', 'str'}, set()),
('Annotated[str, validator(int, 5)]', {'Annotated', 'str'}, {'validator', 'int'}),
('Annotated[str, "bool"]', {'Annotated', 'str'}, set()),
]

if sys.version_info >= (3, 11):
Expand All @@ -31,8 +33,20 @@
])


@pytest.mark.parametrize(('example', 'expected'), examples)
def test_name_extraction(example, expected):
visitor = StringAnnotationVisitor()
@pytest.mark.parametrize(('example', 'expected', 'soft_uses'), examples)
def test_name_extraction(example, expected, soft_uses):
import_visitor = ImportVisitor(
cwd='fake cwd', # type: ignore[arg-type]
pydantic_enabled=False,
fastapi_enabled=False,
fastapi_dependency_support_enabled=False,
cattrs_enabled=False,
sqlalchemy_enabled=False,
sqlalchemy_mapped_dotted_names=[],
injector_enabled=False,
pydantic_enabled_baseclass_passlist=[],
)
import_visitor.visit(ast.parse('from typing import Annotated, Literal, Literal as L'))
visitor = StringAnnotationVisitor(import_visitor)
visitor.parse_and_visit_string_annotation(example)
assert visitor.names == expected
6 changes: 5 additions & 1 deletion tests/test_tc001_to_tc003.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,25 +203,29 @@ class Migration:
'''),
set(),
),
# Annotated soft use
(
textwrap.dedent(f'''
from typing import Annotated
from {import_} import Depends
x: Annotated[str, Depends]
y: Depends
'''),
set(),
),
# This is not a soft-use, it's just a plain string
(
textwrap.dedent(f'''
from typing import Annotated
from {import_} import Depends
x: Annotated[str, "Depends"]
y: Depends
'''),
set(),
{'4:0 ' + ERROR.format(module=f'{import_}.Depends')},
),
]

Expand Down

0 comments on commit f1c791f

Please sign in to comment.