Skip to content

Commit

Permalink
fix: Injector plugin requires all annotations in the injected function
Browse files Browse the repository at this point in the history
  • Loading branch information
Daverball authored Dec 13, 2024
1 parent 7ce3b36 commit c81ff55
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 39 deletions.
27 changes: 16 additions & 11 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,9 @@ class InjectorMixin:
def visit(self, node: ast.AST) -> ast.AST: # noqa: D102
...

def lookup_full_name(self, node: ast.AST) -> str | None: # noqa: D102
...

def visit_FunctionDef(self, node: FunctionDef) -> None:
"""Remove and map function arguments and returns."""
super().visit_FunctionDef(node) # type: ignore[misc]
Expand All @@ -487,6 +490,12 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
if self.injector_enabled:
self.handle_injector_declaration(node)

def _has_injected_annotation(self, node: AsyncFunctionDef | FunctionDef) -> bool:
return any(
isinstance(expr, ast.Subscript) and self.lookup_full_name(expr.value) == 'injector.Inject'
for expr in iter_function_annotation_nodes(node)
)

def handle_injector_declaration(self, node: AsyncFunctionDef | FunctionDef) -> None:
"""
Adjust for injector declaration setting.
Expand All @@ -496,17 +505,11 @@ def handle_injector_declaration(self, node: AsyncFunctionDef | FunctionDef) -> N
To achieve this, we just visit the annotations to register them as "uses".
"""
for path in [node.args.args, node.args.kwonlyargs]:
for argument in path:
if hasattr(argument, 'annotation') and argument.annotation:
annotation = argument.annotation
if not hasattr(annotation, 'value'):
continue
value = annotation.value
if hasattr(value, 'id') and value.id == 'Inject':
self.visit(argument.annotation)
if hasattr(value, 'attr') and value.attr == 'Inject':
self.visit(argument.annotation)
if not self._has_injected_annotation(node):
return

for expr in iter_function_annotation_nodes(node):
self.visit(expr)


class FastAPIMixin:
Expand Down Expand Up @@ -592,6 +595,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> None:
def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
"""Remove and map function arguments and returns."""
super().visit_AsyncFunctionDef(node) # type: ignore[misc]
if self.in_type_checking_block(node.lineno, node.col_offset):
return
if self.has_singledispatch_decorator(node):
for expr in iter_function_annotation_nodes(node):
self.visit(expr)
Expand Down
38 changes: 10 additions & 28 deletions tests/test_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, service: Inject[Service]) -> None:
@pytest.mark.parametrize(
('enabled', 'expected'),
[
(True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}),
(True, set()),
(
False,
{
Expand All @@ -65,8 +65,8 @@ def __init__(self, service: Inject[Service]) -> None:
),
],
)
def test_injector_option_only_allows_injected_dependencies(enabled, expected):
"""Whenever an injector option is enabled, only injected dependencies should be ignored."""
def test_injector_option_all_annotations_in_function_are_runtime_dependencies(enabled, expected):
"""Whenever an argument is injected, all the other annotations are runtime required too."""
example = textwrap.dedent(
'''
from injector import Inject
Expand All @@ -82,38 +82,20 @@ def __init__(self, service: Inject[Service], other: OtherDependency) -> None:
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected


@pytest.mark.parametrize(
('enabled', 'expected'),
[
(True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}),
(
False,
{
'2:0 ' + TC002.format(module='injector.Inject'),
'3:0 ' + TC002.format(module='services.Service'),
'4:0 ' + TC002.format(module='other_dependency.OtherDependency'),
},
),
],
)
def test_injector_option_only_allows_injector_slices(enabled, expected):
"""
Whenever an injector option is enabled, only injected dependencies should be ignored,
not any dependencies with slices.
"""
def test_injector_option_require_injections_under_unpack():
"""Whenever an injector option is enabled, injected dependencies should be ignored, even if unpacked."""
example = textwrap.dedent(
"""
from typing import Unpack
from injector import Inject
from services import Service
from other_dependency import OtherDependency
from services import ServiceKwargs
class X:
def __init__(self, service: Inject[Service], other_deps: list[OtherDependency]) -> None:
def __init__(self, service: Inject[Service], **kwargs: Unpack[ServiceKwargs]) -> None:
self.service = service
self.other_deps = other_deps
self.args = args
"""
)
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=True) == set()


@pytest.mark.parametrize(
Expand Down

0 comments on commit c81ff55

Please sign in to comment.