Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Make use of is_typing for Literal/Annotated detection. #195

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 76 additions & 23 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,17 @@ def ast_unparse(node: ast.AST) -> str:
Import,
ImportTypeValue,
Name,
SupportsIsTyping,
)


class AnnotationVisitor(ABC):
"""Simplified node visitor for traversing annotations."""

@abstractmethod
def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""

@abstractmethod
def visit_annotation_name(self, node: ast.Name) -> None:
"""Visit a name inside an annotation."""
Expand All @@ -80,6 +85,10 @@ def visit_annotation_name(self, node: ast.Name) -> None:
def visit_annotation_string(self, node: ast.Constant) -> None:
"""Visit a string constant inside an annotation."""

def visit_annotated_value(self, node: ast.expr) -> None:
"""Visit a value expression of `typing.Annotated[type, value]`."""
return

def visit(self, node: ast.AST) -> None:
"""Visit relevant child nodes on an annotation."""
if node is None:
Expand All @@ -95,15 +104,19 @@ def visit(self, node: ast.AST) -> None:
self.visit(node.value)
elif isinstance(node, ast.Subscript):
self.visit(node.value)
if getattr(node.value, 'id', '') == 'Annotated' and isinstance(
if self.is_typing(node.value, 'Literal'):
return
elif self.is_typing(node.value, 'Annotated') and isinstance(
(elts_node := node.slice.value if py38 and isinstance(node.slice, Index) else node.slice),
(ast.Tuple, ast.List),
):
if elts_node.elts:
# only visit the first element
self.visit(elts_node.elts[0])
# TODO: We may want to visit the rest as a soft-runtime use
elif getattr(node.value, 'id', '') != 'Literal':
elts_iter = iter(elts_node.elts)
# only visit the first element like a type expression
self.visit(next(elts_iter))
for value_node in elts_iter:
self.visit_annotated_value(value_node)
else:
self.visit(node.slice)
elif isinstance(node, (ast.Tuple, ast.List)):
for n in node.elts:
Expand Down Expand Up @@ -300,18 +313,29 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
class SQLAlchemyAnnotationVisitor(AnnotationVisitor):
"""Adds any names in the annotation to mapped names."""

def __init__(self, mapped_names: set[str]) -> None:
self.mapped_names = mapped_names
def __init__(self, sqlalchemy_plugin: SQLAlchemyMixin) -> None:
self.plugin = sqlalchemy_plugin

def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""
return self.plugin.is_typing(node, symbol)

def visit_annotated_value(self, node: ast.expr) -> None:
"""Visit a value expression of `typing.Annotated[type, value]`."""
previous_context = self.plugin.in_soft_use_context
self.plugin.in_soft_use_context = True
self.plugin.visit(node)
self.plugin.in_soft_use_context = previous_context

def visit_annotation_name(self, node: ast.Name) -> None:
"""Add name to mapped names."""
self.mapped_names.add(node.id)
self.plugin.soft_uses.add(node.id)

def visit_annotation_string(self, node: ast.Constant) -> None:
"""Add all the names in the string to mapped names."""
visitor = StringAnnotationVisitor()
visitor = StringAnnotationVisitor(self.plugin)
visitor.parse_and_visit_string_annotation(node.value)
self.mapped_names.update(visitor.names)
self.plugin.soft_uses.update(visitor.names)


class SQLAlchemyMixin:
Expand All @@ -333,23 +357,26 @@ class SQLAlchemyMixin:
sqlalchemy_mapped_dotted_names: set[str]
current_scope: Scope
uses: dict[str, list[tuple[ast.AST, Scope]]]
soft_uses: set[str]
in_soft_use_context: bool

def in_type_checking_block(self, lineno: int, col_offset: int) -> bool: # noqa: D102
...

def lookup_full_name(self, node: ast.AST) -> str | None: # noqa: D102
...

def is_typing(self, node: ast.AST, symbol: str) -> bool: # noqa: D102
...

def visit(self, node: ast.AST) -> ast.AST: # noqa: D102
...

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

#: Contains a set of all names used inside `Mapped[...]` annotations
# These are treated like soft-uses, i.e. we don't know if it will be
# used at runtime or not
self.mapped_names: set[str] = set()

#: Used for visiting annotations
self.sqlalchemy_annotation_visitor = SQLAlchemyAnnotationVisitor(self.mapped_names)
self.sqlalchemy_annotation_visitor = SQLAlchemyAnnotationVisitor(self)

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
"""Remove all annotations assigments."""
Expand Down Expand Up @@ -403,9 +430,9 @@ def handle_sqlalchemy_annotation(self, node: ast.AST) -> None:
# add all names contained in the inner part of the annotation
# since this is not as strict as an actual runtime use, we don't
# care if we record too much here
visitor = StringAnnotationVisitor()
visitor = StringAnnotationVisitor(self)
visitor.parse_and_visit_string_annotation(inner)
self.mapped_names.update(visitor.names)
self.soft_uses.update(visitor.names)
return

# we only need to handle annotations like `Mapped[...]`
Expand Down Expand Up @@ -780,9 +807,14 @@ def lookup(self, symbol_name: str, use: HasPosition | None = None, runtime_only:
class StringAnnotationVisitor(AnnotationVisitor):
"""Visit a parsed string annotation and collect all the names."""

def __init__(self) -> None:
def __init__(self, typing_lookup: SupportsIsTyping) -> None:
#: All the names referenced inside the annotation
self.names: set[str] = set()
self._typing_lookup = typing_lookup

def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""
return self._typing_lookup.is_typing(node, symbol)

def parse_and_visit_string_annotation(self, annotation: str) -> None:
"""Parse and visit the given string as an annotation expression."""
Expand Down Expand Up @@ -814,7 +846,7 @@ def visit_annotation_string(self, node: ast.Constant) -> None:
class ImportAnnotationVisitor(AnnotationVisitor):
"""Map all annotations on an AST node."""

def __init__(self) -> None:
def __init__(self, import_visitor: ImportVisitor) -> None:
#: All type annotations in the file, without quotes around them
self.unwrapped_annotations: list[UnwrappedAnnotation] = []

Expand All @@ -831,6 +863,12 @@ def __init__(self) -> None:
# e.g. AnnAssign.annotation within a function body will never evaluate
self.never_evaluates = False

self.import_visitor = import_visitor

def is_typing(self, node: ast.AST, symbol: str) -> bool:
"""Check if the given node matches the given typing symbol."""
return self.import_visitor.is_typing(node, symbol)

def visit(
self,
node: ast.AST,
Expand Down Expand Up @@ -864,12 +902,19 @@ def visit_annotation_string(self, node: ast.Constant) -> None:
if getattr(node, BINOP_OPERAND_PROPERTY, False):
self.invalid_binop_literals.append(node)
else:
visitor = StringAnnotationVisitor()
visitor = StringAnnotationVisitor(self.import_visitor)
visitor.parse_and_visit_string_annotation(node.value)
(self.excess_wrapped_annotations if self.never_evaluates else self.wrapped_annotations).append(
WrappedAnnotation(node.lineno, node.col_offset, node.value, visitor.names, self.scope, self.type)
)

def visit_annotated_value(self, node: ast.expr) -> None:
"""Visit a value expression of `typing.Annotated[type, value]`."""
previous_context = self.import_visitor.in_soft_use_context
self.import_visitor.in_soft_use_context = True
self.import_visitor.visit(node)
self.import_visitor.in_soft_use_context = previous_context


class ImportVisitor(
DunderAllMixin, AttrsMixin, InjectorMixin, FastAPIMixin, PydanticMixin, SQLAlchemyMixin, ast.NodeVisitor
Expand Down Expand Up @@ -932,8 +977,12 @@ def __init__(
#: List of all names and ids, except type declarations
self.uses: dict[str, list[tuple[ast.AST, Scope]]] = defaultdict(list)

#: Contains a set of all names to be treated like soft-uses.
# i.e. we don't know if it will be used at runtime or not
self.soft_uses: set[str] = set()
Daverball marked this conversation as resolved.
Show resolved Hide resolved

#: Handles logic of visiting annotation nodes
self.annotation_visitor = ImportAnnotationVisitor()
self.annotation_visitor = ImportAnnotationVisitor(self)

#: Whether there is a `from __futures__ import annotations` is present in the file
self.futures_annotation: Optional[bool] = None
Expand All @@ -951,6 +1000,10 @@ def __init__(
#: For tracking which comprehension/IfExp we're currently inside of
self.active_context: Optional[Comprehension | ast.IfExp] = None

#: Whether or not we're in a context where uses count as soft-uses.
# E.g. the value expression of `typing.Annotated[type, value]`
self.in_soft_use_context: bool = False

@contextmanager
def create_scope(self, node: ast.ClassDef | Function, is_head: bool = True) -> Iterator[Scope]:
"""Create a new scope."""
Expand Down Expand Up @@ -1853,7 +1906,7 @@ def unused_imports(self) -> Flake8Generator:
}

all_imports = {name for name, imp in self.visitor.imports.items() if not imp.exempt}
unused_imports = all_imports - self.visitor.names - self.visitor.mapped_names
unused_imports = all_imports - self.visitor.names - self.visitor.soft_uses
used_imports = all_imports - unused_imports
already_imported_modules = [self.visitor.imports[name].module for name in used_imports]
annotation_names = (
Expand Down
4 changes: 4 additions & 0 deletions flake8_type_checking/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,9 @@ def lineno(self) -> int:
def col_offset(self) -> int:
pass

class SupportsIsTyping(Protocol):
def is_typing(self, node: ast.AST, symbol: str) -> bool:
pass


ImportTypeValue = Literal['APPLICATION', 'THIRD_PARTY', 'BUILTIN', 'FUTURE']
54 changes: 34 additions & 20 deletions tests/test_name_extraction.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,52 @@
import ast
import sys

import pytest

from flake8_type_checking.checker import StringAnnotationVisitor
from flake8_type_checking.checker import ImportVisitor, StringAnnotationVisitor

examples = [
('', set()),
('invalid_syntax]', set()),
('int', {'int'}),
('dict[str, int]', {'dict', 'str', 'int'}),
('', set(), set()),
('invalid_syntax]', set(), set()),
('int', {'int'}, set()),
('dict[str, int]', {'dict', 'str', 'int'}, set()),
# make sure literals don't add names for their contents
('Literal["a"]', {'Literal'}),
("Literal['a']", {'Literal'}),
('Literal[0]', {'Literal'}),
('Literal[1.0]', {'Literal'}),
('Literal[True]', {'Literal'}),
('T | S', {'T', 'S'}),
('Union[Dict[str, Any], Literal["Foo", "Bar"], _T]', {'Union', 'Dict', 'str', 'Any', 'Literal', '_T'}),
('Literal["a"]', {'Literal'}, set()),
("Literal['a']", {'Literal'}, set()),
('Literal[0]', {'Literal'}, set()),
('Literal[1.0]', {'Literal'}, set()),
('Literal[True]', {'Literal'}, set()),
('L[a]', {'L'}, set()),
('T | S', {'T', 'S'}, set()),
('Union[Dict[str, Any], Literal["Foo", "Bar"], _T]', {'Union', 'Dict', 'str', 'Any', 'Literal', '_T'}, set()),
# for attribute access only everything up to the first dot should count
# this matches the behavior of add_annotation
('datetime.date | os.path.sep', {'datetime', 'os'}),
('Nested["str"]', {'Nested', 'str'}),
('Annotated[str, validator]', {'Annotated', 'str'}),
('Annotated[str, "bool"]', {'Annotated', 'str'}),
('datetime.date | os.path.sep', {'datetime', 'os'}, set()),
('Nested["str"]', {'Nested', 'str'}, set()),
('Annotated[str, validator(int, 5)]', {'Annotated', 'str'}, {'validator', 'int'}),
('Annotated[str, "bool"]', {'Annotated', 'str'}, set()),
]

if sys.version_info >= (3, 11):
examples.extend([
('*Ts', {'Ts'}),
('*Ts', {'Ts'}, set()),
])


@pytest.mark.parametrize(('example', 'expected'), examples)
def test_name_extraction(example, expected):
visitor = StringAnnotationVisitor()
@pytest.mark.parametrize(('example', 'expected', 'soft_uses'), examples)
def test_name_extraction(example, expected, soft_uses):
import_visitor = ImportVisitor(
cwd='fake cwd', # type: ignore[arg-type]
pydantic_enabled=False,
fastapi_enabled=False,
fastapi_dependency_support_enabled=False,
cattrs_enabled=False,
sqlalchemy_enabled=False,
sqlalchemy_mapped_dotted_names=[],
injector_enabled=False,
pydantic_enabled_baseclass_passlist=[],
)
import_visitor.visit(ast.parse('from typing import Annotated, Literal, Literal as L'))
visitor = StringAnnotationVisitor(import_visitor)
visitor.parse_and_visit_string_annotation(example)
assert visitor.names == expected
19 changes: 19 additions & 0 deletions tests/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,25 @@ class User:
assert _get_error(example, error_code_filter='TC002', type_checking_sqlalchemy_enabled=True) == set()


def test_mapped_soft_uses():
"""
Everything inside Mapped is a soft-use, including `Annotated` value expresions
as such we can't trigger a TC002 here, despite the only uses being inside
type annotations.
"""
example = textwrap.dedent('''
from foo import Bar, Gt
from sqlalchemy.orm import Mapped
from typing import Annotated

class User:
number: Mapped[Annotated[Bar, Gt(2)]]
bar: Bar
validator: Gt
''')
assert _get_error(example, error_code_filter='TC002', type_checking_sqlalchemy_enabled=True) == set()


def test_mapped_use_without_runtime_import():
"""
Mapped must be available at runtime, so even if it is inside a wrapped annotation
Expand Down
6 changes: 5 additions & 1 deletion tests/test_tc001_to_tc003.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,25 +203,29 @@ class Migration:
'''),
set(),
),
# Annotated soft use
(
textwrap.dedent(f'''
from typing import Annotated

from {import_} import Depends

x: Annotated[str, Depends]
y: Depends
'''),
set(),
),
# This is not a soft-use, it's just a plain string
(
textwrap.dedent(f'''
from typing import Annotated

from {import_} import Depends

x: Annotated[str, "Depends"]
y: Depends
'''),
set(),
{'4:0 ' + ERROR.format(module=f'{import_}.Depends')},
),
]

Expand Down
Loading