Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Make use of is_typing for Literal/Annotated detection. #195

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 118 additions & 26 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,17 @@ def ast_unparse(node: ast.AST) -> str:
Import,
ImportTypeValue,
Name,
SupportsIsTyping,
)


class AnnotationVisitor(ABC):
"""Simplified node visitor for traversing annotations."""

@abstractmethod
def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""

@abstractmethod
def visit_annotation_name(self, node: ast.Name) -> None:
"""Visit a name inside an annotation."""
Expand All @@ -80,6 +85,14 @@ def visit_annotation_name(self, node: ast.Name) -> None:
def visit_annotation_string(self, node: ast.Constant) -> None:
"""Visit a string constant inside an annotation."""

def visit_annotated_type(self, node: ast.expr) -> None:
"""Visit a type expression of `typing.Annotated[type, value]`."""
self.visit(node)

def visit_annotated_value(self, node: ast.expr) -> None:
"""Visit a value expression of `typing.Annotated[type, value]`."""
return

def visit(self, node: ast.AST) -> None:
"""Visit relevant child nodes on an annotation."""
if node is None:
Expand All @@ -95,15 +108,19 @@ def visit(self, node: ast.AST) -> None:
self.visit(node.value)
elif isinstance(node, ast.Subscript):
self.visit(node.value)
if getattr(node.value, 'id', '') == 'Annotated' and isinstance(
if self.is_typing(node.value, 'Literal'):
return
elif self.is_typing(node.value, 'Annotated') and isinstance(
(elts_node := node.slice.value if py38 and isinstance(node.slice, Index) else node.slice),
(ast.Tuple, ast.List),
):
if elts_node.elts:
# only visit the first element
self.visit(elts_node.elts[0])
# TODO: We may want to visit the rest as a soft-runtime use
elif getattr(node.value, 'id', '') != 'Literal':
elts_iter = iter(elts_node.elts)
# only visit the first element like a type expression
self.visit_annotated_type(next(elts_iter))
for value_node in elts_iter:
self.visit_annotated_value(value_node)
else:
self.visit(node.slice)
elif isinstance(node, (ast.Tuple, ast.List)):
for n in node.elts:
Expand Down Expand Up @@ -300,18 +317,29 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
class SQLAlchemyAnnotationVisitor(AnnotationVisitor):
"""Adds any names in the annotation to mapped names."""

def __init__(self, mapped_names: set[str]) -> None:
self.mapped_names = mapped_names
def __init__(self, sqlalchemy_plugin: SQLAlchemyMixin) -> None:
self.plugin = sqlalchemy_plugin

def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""
return self.plugin.is_typing(node, symbol)

def visit_annotated_value(self, node: ast.expr) -> None:
"""Visit a value expression of `typing.Annotated[type, value]`."""
previous_context = self.plugin.in_soft_use_context
self.plugin.in_soft_use_context = True
self.plugin.visit(node)
self.plugin.in_soft_use_context = previous_context

def visit_annotation_name(self, node: ast.Name) -> None:
"""Add name to mapped names."""
self.mapped_names.add(node.id)
self.plugin.soft_uses.add(node.id)

def visit_annotation_string(self, node: ast.Constant) -> None:
"""Add all the names in the string to mapped names."""
visitor = StringAnnotationVisitor()
visitor = StringAnnotationVisitor(self.plugin)
visitor.parse_and_visit_string_annotation(node.value)
self.mapped_names.update(visitor.names)
self.plugin.soft_uses.update(visitor.names)


class SQLAlchemyMixin:
Expand All @@ -333,23 +361,26 @@ class SQLAlchemyMixin:
sqlalchemy_mapped_dotted_names: set[str]
current_scope: Scope
uses: dict[str, list[tuple[ast.AST, Scope]]]
soft_uses: set[str]
in_soft_use_context: bool

def in_type_checking_block(self, lineno: int, col_offset: int) -> bool: # noqa: D102
...

def lookup_full_name(self, node: ast.AST) -> str | None: # noqa: D102
...

def is_typing(self, node: ast.AST, symbol: str) -> bool: # noqa: D102
...

def visit(self, node: ast.AST) -> ast.AST: # noqa: D102
...

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

#: Contains a set of all names used inside `Mapped[...]` annotations
# These are treated like soft-uses, i.e. we don't know if it will be
# used at runtime or not
self.mapped_names: set[str] = set()

#: Used for visiting annotations
self.sqlalchemy_annotation_visitor = SQLAlchemyAnnotationVisitor(self.mapped_names)
self.sqlalchemy_annotation_visitor = SQLAlchemyAnnotationVisitor(self)

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
"""Remove all annotations assigments."""
Expand Down Expand Up @@ -403,9 +434,9 @@ def handle_sqlalchemy_annotation(self, node: ast.AST) -> None:
# add all names contained in the inner part of the annotation
# since this is not as strict as an actual runtime use, we don't
# care if we record too much here
visitor = StringAnnotationVisitor()
visitor = StringAnnotationVisitor(self)
visitor.parse_and_visit_string_annotation(inner)
self.mapped_names.update(visitor.names)
self.soft_uses.update(visitor.names)
return

# we only need to handle annotations like `Mapped[...]`
Expand Down Expand Up @@ -780,9 +811,14 @@ def lookup(self, symbol_name: str, use: HasPosition | None = None, runtime_only:
class StringAnnotationVisitor(AnnotationVisitor):
"""Visit a parsed string annotation and collect all the names."""

def __init__(self) -> None:
def __init__(self, typing_lookup: SupportsIsTyping) -> None:
#: All the names referenced inside the annotation
self.names: set[str] = set()
self._typing_lookup = typing_lookup

def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""
return self._typing_lookup.is_typing(node, symbol)

def parse_and_visit_string_annotation(self, annotation: str) -> None:
"""Parse and visit the given string as an annotation expression."""
Expand Down Expand Up @@ -814,7 +850,7 @@ def visit_annotation_string(self, node: ast.Constant) -> None:
class ImportAnnotationVisitor(AnnotationVisitor):
"""Map all annotations on an AST node."""

def __init__(self) -> None:
def __init__(self, import_visitor: ImportVisitor) -> None:
#: All type annotations in the file, without quotes around them
self.unwrapped_annotations: list[UnwrappedAnnotation] = []

Expand All @@ -831,6 +867,16 @@ def __init__(self) -> None:
# e.g. AnnAssign.annotation within a function body will never evaluate
self.never_evaluates = False

#: Whether or not we're currently in a soft-use context
# e.g. the type expression of `Annotated[type, value]
self.in_soft_use_context = False

self.import_visitor = import_visitor

def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""
return self.import_visitor.is_typing(node, symbol)

def visit(
self,
node: ast.AST,
Expand All @@ -853,6 +899,9 @@ def visit_annotation_name(self, node: ast.Name) -> None:
if self.never_evaluates:
return

if self.in_soft_use_context:
self.import_visitor.soft_uses.add(node.id)

self.unwrapped_annotations.append(
UnwrappedAnnotation(node.lineno, node.col_offset, node.id, self.scope, self.type)
)
Expand All @@ -864,12 +913,40 @@ def visit_annotation_string(self, node: ast.Constant) -> None:
if getattr(node, BINOP_OPERAND_PROPERTY, False):
self.invalid_binop_literals.append(node)
else:
visitor = StringAnnotationVisitor()
visitor = StringAnnotationVisitor(self.import_visitor)
visitor.parse_and_visit_string_annotation(node.value)
if self.in_soft_use_context:
self.import_visitor.soft_uses.update(visitor.names)
(self.excess_wrapped_annotations if self.never_evaluates else self.wrapped_annotations).append(
WrappedAnnotation(node.lineno, node.col_offset, node.value, visitor.names, self.scope, self.type)
)

def visit_annotated_type(self, node: ast.expr) -> None:
"""Visit a type expression of `typing.Annotated[type, value]`."""
if self.never_evaluates:
return

previous_context = self.in_soft_use_context
self.in_soft_use_context = True
self.visit(node)
self.in_soft_use_context = previous_context

def visit_annotated_value(self, node: ast.expr) -> None:
"""Visit a value expression of `typing.Annotated[type, value]`."""
if self.never_evaluates:
return

if self.type == 'alias' or (self.type == 'annotation' and not self.import_visitor.futures_annotation):
# visit nodes in regular runtime context
self.import_visitor.visit(node)
return

# visit nodes in soft use context
previous_context = self.import_visitor.in_soft_use_context
self.import_visitor.in_soft_use_context = True
self.import_visitor.visit(node)
self.import_visitor.in_soft_use_context = previous_context


class ImportVisitor(
DunderAllMixin, AttrsMixin, InjectorMixin, FastAPIMixin, PydanticMixin, SQLAlchemyMixin, ast.NodeVisitor
Expand Down Expand Up @@ -932,10 +1009,15 @@ def __init__(
#: List of all names and ids, except type declarations
self.uses: dict[str, list[tuple[ast.AST, Scope]]] = defaultdict(list)

#: Contains a set of all names to be treated like soft-uses.
# i.e. we don't know if it will be used at runtime or not, so
# we should assume the imports are currently correct
self.soft_uses: set[str] = set()

#: Handles logic of visiting annotation nodes
self.annotation_visitor = ImportAnnotationVisitor()
self.annotation_visitor = ImportAnnotationVisitor(self)

#: Whether there is a `from __futures__ import annotations` is present in the file
#: Whether there is a `from __futures__ import annotations` present in the file
self.futures_annotation: Optional[bool] = None

#: Where the type checking block exists (line_start, line_end, col_offset)
Expand All @@ -951,6 +1033,10 @@ def __init__(
#: For tracking which comprehension/IfExp we're currently inside of
self.active_context: Optional[Comprehension | ast.IfExp] = None

#: Whether or not we're in a context where uses count as soft-uses.
# E.g. the type expression of `typing.Annotated[type, value]`
self.in_soft_use_context: bool = False

@contextmanager
def create_scope(self, node: ast.ClassDef | Function, is_head: bool = True) -> Iterator[Scope]:
"""Create a new scope."""
Expand Down Expand Up @@ -1334,10 +1420,16 @@ def visit_Name(self, node: ast.Name) -> ast.Name:
if self.in_type_checking_block(node.lineno, node.col_offset):
return node

names = [node.id]
if hasattr(node, ATTRIBUTE_PROPERTY):
self.uses[f'{node.id}.{getattr(node, ATTRIBUTE_PROPERTY)}'].append((node, self.current_scope))
names.append(f'{node.id}.{getattr(node, ATTRIBUTE_PROPERTY)}')

if self.in_soft_use_context:
self.soft_uses.update(names)
else:
for name in names:
self.uses[name].append((node, self.current_scope))

self.uses[node.id].append((node, self.current_scope))
return node

def visit_Constant(self, node: ast.Constant) -> ast.Constant:
Expand Down Expand Up @@ -1853,7 +1945,7 @@ def unused_imports(self) -> Flake8Generator:
}

all_imports = {name for name, imp in self.visitor.imports.items() if not imp.exempt}
unused_imports = all_imports - self.visitor.names - self.visitor.mapped_names
unused_imports = all_imports - self.visitor.names - self.visitor.soft_uses
used_imports = all_imports - unused_imports
already_imported_modules = [self.visitor.imports[name].module for name in used_imports]
annotation_names = (
Expand Down
4 changes: 4 additions & 0 deletions flake8_type_checking/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,9 @@ def lineno(self) -> int:
def col_offset(self) -> int:
pass

class SupportsIsTyping(Protocol):
def is_typing(self, node: ast.AST, symbol: str) -> bool:
pass


ImportTypeValue = Literal['APPLICATION', 'THIRD_PARTY', 'BUILTIN', 'FUTURE']
20 changes: 17 additions & 3 deletions tests/test_name_extraction.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import ast
import sys

import pytest

from flake8_type_checking.checker import StringAnnotationVisitor
from flake8_type_checking.checker import ImportVisitor, StringAnnotationVisitor

examples = [
('', set()),
Expand All @@ -15,13 +16,14 @@
('Literal[0]', {'Literal'}),
('Literal[1.0]', {'Literal'}),
('Literal[True]', {'Literal'}),
('L[a]', {'L'}),
('T | S', {'T', 'S'}),
('Union[Dict[str, Any], Literal["Foo", "Bar"], _T]', {'Union', 'Dict', 'str', 'Any', 'Literal', '_T'}),
# for attribute access only everything up to the first dot should count
# this matches the behavior of add_annotation
('datetime.date | os.path.sep', {'datetime', 'os'}),
('Nested["str"]', {'Nested', 'str'}),
('Annotated[str, validator]', {'Annotated', 'str'}),
('Annotated[str, validator(int, 5)]', {'Annotated', 'str'}),
('Annotated[str, "bool"]', {'Annotated', 'str'}),
]

Expand All @@ -33,6 +35,18 @@

@pytest.mark.parametrize(('example', 'expected'), examples)
def test_name_extraction(example, expected):
visitor = StringAnnotationVisitor()
import_visitor = ImportVisitor(
cwd='fake cwd', # type: ignore[arg-type]
pydantic_enabled=False,
fastapi_enabled=False,
fastapi_dependency_support_enabled=False,
cattrs_enabled=False,
sqlalchemy_enabled=False,
sqlalchemy_mapped_dotted_names=[],
injector_enabled=False,
pydantic_enabled_baseclass_passlist=[],
)
import_visitor.visit(ast.parse('from typing import Annotated, Literal, Literal as L'))
visitor = StringAnnotationVisitor(import_visitor)
visitor.parse_and_visit_string_annotation(example)
assert visitor.names == expected
Loading
Loading