Skip to content

Commit

Permalink
refactor: Make use of is_typing for Literal/Annotated detection.
Browse files Browse the repository at this point in the history
Makes soft use concept of `mapped_names` more generic and reuse it for
walking `Annotated` type and/or value expressions depending on context.
  • Loading branch information
Daverball authored Dec 12, 2024
1 parent 68d42a8 commit 7cf333d
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 67 deletions.
144 changes: 118 additions & 26 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,17 @@ def ast_unparse(node: ast.AST) -> str:
Import,
ImportTypeValue,
Name,
SupportsIsTyping,
)


class AnnotationVisitor(ABC):
"""Simplified node visitor for traversing annotations."""

@abstractmethod
def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""

@abstractmethod
def visit_annotation_name(self, node: ast.Name) -> None:
"""Visit a name inside an annotation."""
Expand All @@ -80,6 +85,14 @@ def visit_annotation_name(self, node: ast.Name) -> None:
def visit_annotation_string(self, node: ast.Constant) -> None:
"""Visit a string constant inside an annotation."""

def visit_annotated_type(self, node: ast.expr) -> None:
"""Visit a type expression of `typing.Annotated[type, value]`."""
self.visit(node)

def visit_annotated_value(self, node: ast.expr) -> None:
"""Visit a value expression of `typing.Annotated[type, value]`."""
return

def visit(self, node: ast.AST) -> None:
"""Visit relevant child nodes on an annotation."""
if node is None:
Expand All @@ -95,15 +108,19 @@ def visit(self, node: ast.AST) -> None:
self.visit(node.value)
elif isinstance(node, ast.Subscript):
self.visit(node.value)
if getattr(node.value, 'id', '') == 'Annotated' and isinstance(
if self.is_typing(node.value, 'Literal'):
return
elif self.is_typing(node.value, 'Annotated') and isinstance(
(elts_node := node.slice.value if py38 and isinstance(node.slice, Index) else node.slice),
(ast.Tuple, ast.List),
):
if elts_node.elts:
# only visit the first element
self.visit(elts_node.elts[0])
# TODO: We may want to visit the rest as a soft-runtime use
elif getattr(node.value, 'id', '') != 'Literal':
elts_iter = iter(elts_node.elts)
# only visit the first element like a type expression
self.visit_annotated_type(next(elts_iter))
for value_node in elts_iter:
self.visit_annotated_value(value_node)
else:
self.visit(node.slice)
elif isinstance(node, (ast.Tuple, ast.List)):
for n in node.elts:
Expand Down Expand Up @@ -300,18 +317,29 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
class SQLAlchemyAnnotationVisitor(AnnotationVisitor):
"""Adds any names in the annotation to mapped names."""

def __init__(self, mapped_names: set[str]) -> None:
self.mapped_names = mapped_names
def __init__(self, sqlalchemy_plugin: SQLAlchemyMixin) -> None:
self.plugin = sqlalchemy_plugin

def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""
return self.plugin.is_typing(node, symbol)

def visit_annotated_value(self, node: ast.expr) -> None:
"""Visit a value expression of `typing.Annotated[type, value]`."""
previous_context = self.plugin.in_soft_use_context
self.plugin.in_soft_use_context = True
self.plugin.visit(node)
self.plugin.in_soft_use_context = previous_context

def visit_annotation_name(self, node: ast.Name) -> None:
"""Add name to mapped names."""
self.mapped_names.add(node.id)
self.plugin.soft_uses.add(node.id)

def visit_annotation_string(self, node: ast.Constant) -> None:
"""Add all the names in the string to mapped names."""
visitor = StringAnnotationVisitor()
visitor = StringAnnotationVisitor(self.plugin)
visitor.parse_and_visit_string_annotation(node.value)
self.mapped_names.update(visitor.names)
self.plugin.soft_uses.update(visitor.names)


class SQLAlchemyMixin:
Expand All @@ -333,23 +361,26 @@ class SQLAlchemyMixin:
sqlalchemy_mapped_dotted_names: set[str]
current_scope: Scope
uses: dict[str, list[tuple[ast.AST, Scope]]]
soft_uses: set[str]
in_soft_use_context: bool

def in_type_checking_block(self, lineno: int, col_offset: int) -> bool: # noqa: D102
...

def lookup_full_name(self, node: ast.AST) -> str | None: # noqa: D102
...

def is_typing(self, node: ast.AST, symbol: str) -> bool: # noqa: D102
...

def visit(self, node: ast.AST) -> ast.AST: # noqa: D102
...

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

#: Contains a set of all names used inside `Mapped[...]` annotations
# These are treated like soft-uses, i.e. we don't know if it will be
# used at runtime or not
self.mapped_names: set[str] = set()

#: Used for visiting annotations
self.sqlalchemy_annotation_visitor = SQLAlchemyAnnotationVisitor(self.mapped_names)
self.sqlalchemy_annotation_visitor = SQLAlchemyAnnotationVisitor(self)

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
"""Remove all annotations assigments."""
Expand Down Expand Up @@ -403,9 +434,9 @@ def handle_sqlalchemy_annotation(self, node: ast.AST) -> None:
# add all names contained in the inner part of the annotation
# since this is not as strict as an actual runtime use, we don't
# care if we record too much here
visitor = StringAnnotationVisitor()
visitor = StringAnnotationVisitor(self)
visitor.parse_and_visit_string_annotation(inner)
self.mapped_names.update(visitor.names)
self.soft_uses.update(visitor.names)
return

# we only need to handle annotations like `Mapped[...]`
Expand Down Expand Up @@ -780,9 +811,14 @@ def lookup(self, symbol_name: str, use: HasPosition | None = None, runtime_only:
class StringAnnotationVisitor(AnnotationVisitor):
"""Visit a parsed string annotation and collect all the names."""

def __init__(self) -> None:
def __init__(self, typing_lookup: SupportsIsTyping) -> None:
#: All the names referenced inside the annotation
self.names: set[str] = set()
self._typing_lookup = typing_lookup

def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""
return self._typing_lookup.is_typing(node, symbol)

def parse_and_visit_string_annotation(self, annotation: str) -> None:
"""Parse and visit the given string as an annotation expression."""
Expand Down Expand Up @@ -814,7 +850,7 @@ def visit_annotation_string(self, node: ast.Constant) -> None:
class ImportAnnotationVisitor(AnnotationVisitor):
"""Map all annotations on an AST node."""

def __init__(self) -> None:
def __init__(self, import_visitor: ImportVisitor) -> None:
#: All type annotations in the file, without quotes around them
self.unwrapped_annotations: list[UnwrappedAnnotation] = []

Expand All @@ -831,6 +867,16 @@ def __init__(self) -> None:
# e.g. AnnAssign.annotation within a function body will never evaluate
self.never_evaluates = False

#: Whether or not we're currently in a soft-use context
# e.g. the type expression of `Annotated[type, value]
self.in_soft_use_context = False

self.import_visitor = import_visitor

def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""
return self.import_visitor.is_typing(node, symbol)

def visit(
self,
node: ast.AST,
Expand All @@ -853,6 +899,9 @@ def visit_annotation_name(self, node: ast.Name) -> None:
if self.never_evaluates:
return

if self.in_soft_use_context:
self.import_visitor.soft_uses.add(node.id)

self.unwrapped_annotations.append(
UnwrappedAnnotation(node.lineno, node.col_offset, node.id, self.scope, self.type)
)
Expand All @@ -864,12 +913,40 @@ def visit_annotation_string(self, node: ast.Constant) -> None:
if getattr(node, BINOP_OPERAND_PROPERTY, False):
self.invalid_binop_literals.append(node)
else:
visitor = StringAnnotationVisitor()
visitor = StringAnnotationVisitor(self.import_visitor)
visitor.parse_and_visit_string_annotation(node.value)
if self.in_soft_use_context:
self.import_visitor.soft_uses.update(visitor.names)
(self.excess_wrapped_annotations if self.never_evaluates else self.wrapped_annotations).append(
WrappedAnnotation(node.lineno, node.col_offset, node.value, visitor.names, self.scope, self.type)
)

def visit_annotated_type(self, node: ast.expr) -> None:
"""Visit a type expression of `typing.Annotated[type, value]`."""
if self.never_evaluates:
return

previous_context = self.in_soft_use_context
self.in_soft_use_context = True
self.visit(node)
self.in_soft_use_context = previous_context

def visit_annotated_value(self, node: ast.expr) -> None:
"""Visit a value expression of `typing.Annotated[type, value]`."""
if self.never_evaluates:
return

if self.type == 'alias' or (self.type == 'annotation' and not self.import_visitor.futures_annotation):
# visit nodes in regular runtime context
self.import_visitor.visit(node)
return

# visit nodes in soft use context
previous_context = self.import_visitor.in_soft_use_context
self.import_visitor.in_soft_use_context = True
self.import_visitor.visit(node)
self.import_visitor.in_soft_use_context = previous_context


class ImportVisitor(
DunderAllMixin, AttrsMixin, InjectorMixin, FastAPIMixin, PydanticMixin, SQLAlchemyMixin, ast.NodeVisitor
Expand Down Expand Up @@ -932,10 +1009,15 @@ def __init__(
#: List of all names and ids, except type declarations
self.uses: dict[str, list[tuple[ast.AST, Scope]]] = defaultdict(list)

#: Contains a set of all names to be treated like soft-uses.
# i.e. we don't know if it will be used at runtime or not, so
# we should assume the imports are currently correct
self.soft_uses: set[str] = set()

#: Handles logic of visiting annotation nodes
self.annotation_visitor = ImportAnnotationVisitor()
self.annotation_visitor = ImportAnnotationVisitor(self)

#: Whether there is a `from __futures__ import annotations` is present in the file
#: Whether there is a `from __futures__ import annotations` present in the file
self.futures_annotation: Optional[bool] = None

#: Where the type checking block exists (line_start, line_end, col_offset)
Expand All @@ -951,6 +1033,10 @@ def __init__(
#: For tracking which comprehension/IfExp we're currently inside of
self.active_context: Optional[Comprehension | ast.IfExp] = None

#: Whether or not we're in a context where uses count as soft-uses.
# E.g. the type expression of `typing.Annotated[type, value]`
self.in_soft_use_context: bool = False

@contextmanager
def create_scope(self, node: ast.ClassDef | Function, is_head: bool = True) -> Iterator[Scope]:
"""Create a new scope."""
Expand Down Expand Up @@ -1334,10 +1420,16 @@ def visit_Name(self, node: ast.Name) -> ast.Name:
if self.in_type_checking_block(node.lineno, node.col_offset):
return node

names = [node.id]
if hasattr(node, ATTRIBUTE_PROPERTY):
self.uses[f'{node.id}.{getattr(node, ATTRIBUTE_PROPERTY)}'].append((node, self.current_scope))
names.append(f'{node.id}.{getattr(node, ATTRIBUTE_PROPERTY)}')

if self.in_soft_use_context:
self.soft_uses.update(names)
else:
for name in names:
self.uses[name].append((node, self.current_scope))

self.uses[node.id].append((node, self.current_scope))
return node

def visit_Constant(self, node: ast.Constant) -> ast.Constant:
Expand Down Expand Up @@ -1853,7 +1945,7 @@ def unused_imports(self) -> Flake8Generator:
}

all_imports = {name for name, imp in self.visitor.imports.items() if not imp.exempt}
unused_imports = all_imports - self.visitor.names - self.visitor.mapped_names
unused_imports = all_imports - self.visitor.names - self.visitor.soft_uses
used_imports = all_imports - unused_imports
already_imported_modules = [self.visitor.imports[name].module for name in used_imports]
annotation_names = (
Expand Down
4 changes: 4 additions & 0 deletions flake8_type_checking/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,9 @@ def lineno(self) -> int:
def col_offset(self) -> int:
pass

class SupportsIsTyping(Protocol):
def is_typing(self, node: ast.AST, symbol: str) -> bool:
pass


ImportTypeValue = Literal['APPLICATION', 'THIRD_PARTY', 'BUILTIN', 'FUTURE']
20 changes: 17 additions & 3 deletions tests/test_name_extraction.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import ast
import sys

import pytest

from flake8_type_checking.checker import StringAnnotationVisitor
from flake8_type_checking.checker import ImportVisitor, StringAnnotationVisitor

examples = [
('', set()),
Expand All @@ -15,13 +16,14 @@
('Literal[0]', {'Literal'}),
('Literal[1.0]', {'Literal'}),
('Literal[True]', {'Literal'}),
('L[a]', {'L'}),
('T | S', {'T', 'S'}),
('Union[Dict[str, Any], Literal["Foo", "Bar"], _T]', {'Union', 'Dict', 'str', 'Any', 'Literal', '_T'}),
# for attribute access only everything up to the first dot should count
# this matches the behavior of add_annotation
('datetime.date | os.path.sep', {'datetime', 'os'}),
('Nested["str"]', {'Nested', 'str'}),
('Annotated[str, validator]', {'Annotated', 'str'}),
('Annotated[str, validator(int, 5)]', {'Annotated', 'str'}),
('Annotated[str, "bool"]', {'Annotated', 'str'}),
]

Expand All @@ -33,6 +35,18 @@

@pytest.mark.parametrize(('example', 'expected'), examples)
def test_name_extraction(example, expected):
visitor = StringAnnotationVisitor()
import_visitor = ImportVisitor(
cwd='fake cwd', # type: ignore[arg-type]
pydantic_enabled=False,
fastapi_enabled=False,
fastapi_dependency_support_enabled=False,
cattrs_enabled=False,
sqlalchemy_enabled=False,
sqlalchemy_mapped_dotted_names=[],
injector_enabled=False,
pydantic_enabled_baseclass_passlist=[],
)
import_visitor.visit(ast.parse('from typing import Annotated, Literal, Literal as L'))
visitor = StringAnnotationVisitor(import_visitor)
visitor.parse_and_visit_string_annotation(example)
assert visitor.names == expected
Loading

0 comments on commit 7cf333d

Please sign in to comment.