diff --git a/flake8_type_checking/checker.py b/flake8_type_checking/checker.py index d98e883..97cac09 100644 --- a/flake8_type_checking/checker.py +++ b/flake8_type_checking/checker.py @@ -475,6 +475,9 @@ class InjectorMixin: def visit(self, node: ast.AST) -> ast.AST: # noqa: D102 ... + def lookup_full_name(self, node: ast.AST) -> str | None: # noqa: D102 + ... + def visit_FunctionDef(self, node: FunctionDef) -> None: """Remove and map function arguments and returns.""" super().visit_FunctionDef(node) # type: ignore[misc] @@ -487,6 +490,12 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None: if self.injector_enabled: self.handle_injector_declaration(node) + def _has_injected_annotation(self, node: AsyncFunctionDef | FunctionDef) -> bool: + return any( + isinstance(expr, ast.Subscript) and self.lookup_full_name(expr.value) == 'injector.Inject' + for expr in iter_function_annotation_nodes(node) + ) + def handle_injector_declaration(self, node: AsyncFunctionDef | FunctionDef) -> None: """ Adjust for injector declaration setting. @@ -496,17 +505,11 @@ def handle_injector_declaration(self, node: AsyncFunctionDef | FunctionDef) -> N To achieve this, we just visit the annotations to register them as "uses". """ - for path in [node.args.args, node.args.kwonlyargs]: - for argument in path: - if hasattr(argument, 'annotation') and argument.annotation: - annotation = argument.annotation - if not hasattr(annotation, 'value'): - continue - value = annotation.value - if hasattr(value, 'id') and value.id == 'Inject': - self.visit(argument.annotation) - if hasattr(value, 'attr') and value.attr == 'Inject': - self.visit(argument.annotation) + if not self._has_injected_annotation(node): + return + + for expr in iter_function_annotation_nodes(node): + self.visit(expr) class FastAPIMixin: @@ -592,6 +595,8 @@ def visit_FunctionDef(self, node: FunctionDef) -> None: def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None: """Remove and map function arguments and returns.""" super().visit_AsyncFunctionDef(node) # type: ignore[misc] + if self.in_type_checking_block(node.lineno, node.col_offset): + return if self.has_singledispatch_decorator(node): for expr in iter_function_annotation_nodes(node): self.visit(expr) diff --git a/tests/test_injector.py b/tests/test_injector.py index 6ea5f75..d9267d0 100644 --- a/tests/test_injector.py +++ b/tests/test_injector.py @@ -54,7 +54,7 @@ def __init__(self, service: Inject[Service]) -> None: @pytest.mark.parametrize( ('enabled', 'expected'), [ - (True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}), + (True, set()), ( False, { @@ -65,8 +65,8 @@ def __init__(self, service: Inject[Service]) -> None: ), ], ) -def test_injector_option_only_allows_injected_dependencies(enabled, expected): - """Whenever an injector option is enabled, only injected dependencies should be ignored.""" +def test_injector_option_all_annotations_in_function_are_runtime_dependencies(enabled, expected): + """Whenever an argument is injected, all the other annotations are runtime required too.""" example = textwrap.dedent( ''' from injector import Inject @@ -82,38 +82,20 @@ def __init__(self, service: Inject[Service], other: OtherDependency) -> None: assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected -@pytest.mark.parametrize( - ('enabled', 'expected'), - [ - (True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}), - ( - False, - { - '2:0 ' + TC002.format(module='injector.Inject'), - '3:0 ' + TC002.format(module='services.Service'), - '4:0 ' + TC002.format(module='other_dependency.OtherDependency'), - }, - ), - ], -) -def test_injector_option_only_allows_injector_slices(enabled, expected): - """ - Whenever an injector option is enabled, only injected dependencies should be ignored, - not any dependencies with slices. - """ +def test_injector_option_require_injections_under_unpack(): + """Whenever an injector option is enabled, injected dependencies should be ignored, even if unpacked.""" example = textwrap.dedent( """ + from typing import Unpack from injector import Inject - from services import Service - from other_dependency import OtherDependency - + from services import ServiceKwargs class X: - def __init__(self, service: Inject[Service], other_deps: list[OtherDependency]) -> None: + def __init__(self, service: Inject[Service], **kwargs: Unpack[ServiceKwargs]) -> None: self.service = service - self.other_deps = other_deps + self.args = args """ ) - assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected + assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=True) == set() @pytest.mark.parametrize(