diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fa1fbc5..9cff073 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,6 +40,8 @@ repos: rev: v1.11.2 hooks: - id: mypy + # uses py311 syntax, mypy configured for py39 + exclude: tests/eval_files/async123.py - repo: https://github.com/RobertCraigie/pyright-python rev: v1.1.384 diff --git a/flake8_async/visitors/__init__.py b/flake8_async/visitors/__init__.py index 0b05011..f1b6199 100644 --- a/flake8_async/visitors/__init__.py +++ b/flake8_async/visitors/__init__.py @@ -36,6 +36,7 @@ visitor105, visitor111, visitor118, + visitor123, visitor_utility, visitors, ) diff --git a/flake8_async/visitors/visitor123.py b/flake8_async/visitors/visitor123.py new file mode 100644 index 0000000..0fd99bb --- /dev/null +++ b/flake8_async/visitors/visitor123.py @@ -0,0 +1,110 @@ +"""foo.""" + +from __future__ import annotations + +import ast +from typing import TYPE_CHECKING, Any + +from .flake8asyncvisitor import Flake8AsyncVisitor +from .helpers import error_class + +if TYPE_CHECKING: + from collections.abc import Mapping + + +@error_class +class Visitor123(Flake8AsyncVisitor): + error_codes: Mapping[str, str] = { + "ASYNC123": ( + "Raising a child exception of an exception group loses" + " context, cause, and/or traceback of the exception inside the group." + ) + } + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.try_star = False + self.exception_group_names: set[str] = set() + self.child_exception_list_names: set[str] = set() + self.child_exception_names: set[str] = set() + + def _is_exception_group(self, node: ast.expr) -> bool: + return ( + (isinstance(node, ast.Name) and node.id in self.exception_group_names) + or ( + # a child exception might be an ExceptionGroup + self._is_child_exception(node) + ) + or ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and self._is_exception_group(node.func.value) + and node.func.attr in ("subgroup", "split") + ) + ) + + def _is_exception_list(self, node: ast.expr | None) -> bool: + return ( + isinstance(node, ast.Name) and node.id in self.child_exception_list_names + ) or ( + isinstance(node, ast.Attribute) + and node.attr == "exceptions" + and self._is_exception_group(node.value) + ) + + def _is_child_exception(self, node: ast.expr | None) -> bool: + return ( + isinstance(node, ast.Name) and node.id in self.child_exception_names + ) or (isinstance(node, ast.Subscript) and self._is_exception_list(node.value)) + + def visit_Raise(self, node: ast.Raise): + if self._is_child_exception(node.exc): + self.error(node) + + def visit_ExceptHandler(self, node: ast.ExceptHandler): + self.save_state( + node, + "exception_group_names", + "child_exception_list_names", + "child_exception_names", + copy=True, + ) + if node.name is None or ( + not self.try_star + and (node.type is None or "ExceptionGroup" not in ast.unparse(node.type)) + ): + self.novisit = True + return + self.exception_group_names = {node.name} + + # ast.TryStar added in py311 + def visit_TryStar(self, node: ast.TryStar): # type: ignore[name-defined] + self.save_state(node, "try_star", copy=False) + self.try_star = True + + def visit_Assign(self, node: ast.Assign | ast.AnnAssign): + if node.value is None or not self.exception_group_names: + return + targets = (node.target,) if isinstance(node, ast.AnnAssign) else node.targets + if self._is_child_exception(node.value): + for target in targets: + if isinstance(target, ast.Name): + self.child_exception_names.add(target.id) + elif self._is_exception_list(node.value): + if len(targets) == 1 and isinstance(targets[0], ast.Name): + self.child_exception_list_names.add(targets[0].id) + # unpacking tuples and Starred and shit. Not implemented + elif self._is_exception_group(node.value): + for target in targets: + if isinstance(target, ast.Name): + self.exception_group_names.add(target.id) + elif isinstance(target, ast.Tuple): + for t in target.elts: + if isinstance(t, ast.Name): + self.exception_group_names.add(t.id) + + visit_AnnAssign = visit_Assign + + def visit_For(self, node: ast.For): + if self._is_exception_list(node.iter) and isinstance(node.target, ast.Name): + self.child_exception_names.add(node.target.id) diff --git a/tests/eval_files/async123.py b/tests/eval_files/async123.py new file mode 100644 index 0000000..36c5208 --- /dev/null +++ b/tests/eval_files/async123.py @@ -0,0 +1,116 @@ +import copy +import sys +from typing import Any + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup, ExceptionGroup + + +def condition() -> bool: + return True + + +def any_fun(arg: Exception) -> Exception: + return arg + + +try: + ... +except ExceptionGroup as e: + if condition(): + raise e.exceptions[0] # error: 8 + elif condition(): + raise copy.copy(e.exceptions[0]) # safe + elif condition(): + raise copy.deepcopy(e.exceptions[0]) # safe + else: + raise any_fun(e.exceptions[0]) # safe +try: + ... +except BaseExceptionGroup as e: + raise e.exceptions[0] # error: 4 +try: + ... +except ExceptionGroup as e: + my_e = e.exceptions[0] + raise my_e # error: 4 +try: + ... +except ExceptionGroup as e: + excs = e.exceptions + my_e = excs[0] + raise my_e # error: 4 +try: + ... +except ExceptionGroup as e: + excs_2 = e.subgroup(bool) + if excs_2: + raise excs_2.exceptions[0] # error: 8 +try: + ... +except ExceptionGroup as e: + excs_1, excs_2 = e.split(bool) + if excs_1: + raise excs_1.exceptions[0] # error: 8 + if excs_2: + raise excs_2.exceptions[0] # error: 8 + +try: + ... +except ExceptionGroup as e: + f = e + raise f.exceptions[0] # error: 4 +try: + ... +except ExceptionGroup as e: + excs = e.exceptions + excs2 = excs + raise excs2[0] # error: 4 +try: + ... +except ExceptionGroup as e: + my_exc = e.exceptions[0] + my_exc2 = my_exc + raise my_exc2 # error: 4 + +try: + ... +except* Exception as e: + raise e.exceptions[0] # error: 4 + +try: + ... +except ExceptionGroup as e: + raise e.exceptions[0].exceptions[0] # error: 4 +try: + ... +except ExceptionGroup as e: + excs = e.exceptions + for exc in excs: + if ...: + raise exc # error: 12 + raise +try: + ... +except ExceptionGroup as e: + ff: ExceptionGroup[Exception] = e + raise ff.exceptions[0] # error: 4 +try: + ... +except ExceptionGroup as e: + raise e.subgroup(bool).exceptions[0] # type: ignore # error: 4 + +# not implemented +try: + ... +except ExceptionGroup as e: + a, *b = e.exceptions + raise a + +# not implemented +try: + ... +except ExceptionGroup as e: + x: Any = object() + x.y = e + raise x.y.exceptions[0] diff --git a/tests/test_flake8_async.py b/tests/test_flake8_async.py index 5219f27..9676bbc 100644 --- a/tests/test_flake8_async.py +++ b/tests/test_flake8_async.py @@ -482,6 +482,7 @@ def _parse_eval_file( # doesn't check for it "ASYNC121", "ASYNC122", + "ASYNC123", "ASYNC300", "ASYNC912", }