From 94f2a5ae891d571bd14b18c1a2d1ea15703779f3 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 3 May 2024 15:19:03 +0200 Subject: [PATCH 1/5] WIP: async912 --- flake8_async/visitors/helpers.py | 2 + flake8_async/visitors/visitor91x.py | 26 ++++- tests/autofix_files/async91x_autofix.py | 17 +++ tests/autofix_files/async91x_autofix.py.diff | 13 +-- tests/eval_files/async912.py | 108 +++++++++++++++++++ tests/eval_files/async912_asyncio.py | 19 ++++ tests/eval_files/async91x_autofix.py | 17 +++ tests/test_flake8_async.py | 3 +- 8 files changed, 195 insertions(+), 10 deletions(-) create mode 100644 tests/eval_files/async912.py create mode 100644 tests/eval_files/async912_asyncio.py diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py index f8521b3b..46f2b2e7 100644 --- a/flake8_async/visitors/helpers.py +++ b/flake8_async/visitors/helpers.py @@ -320,6 +320,8 @@ class AttributeCall(NamedTuple): def with_has_call( node: cst.With, *names: str, base: Iterable[str] = ("trio", "anyio") ) -> list[AttributeCall]: + if isinstance(base, str): + base = (base,) res_list: list[AttributeCall] = [] for item in node.items: if res := m.extract( diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index 988bc3a9..24c0b64c 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -18,11 +18,13 @@ from ..base import Statement from .flake8asyncvisitor import Flake8AsyncVisitor_cst from .helpers import ( + cancel_scope_names, disabled_by_default, error_class_cst, fnmatch_qualified_name_cst, func_has_decorator, iter_guaranteed_once_cst, + with_has_call, ) if TYPE_CHECKING: @@ -243,6 +245,10 @@ class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors): "{0} from async iterable with no guaranteed checkpoint since {1.name} " "on line {1.lineno}." ), + "ASYNC912": ( + "CancelScope with no guaranteed checkpoint. This makes it potentially " + "impossible to cancel." + ), } def __init__(self, *args: Any, **kwargs: Any): @@ -420,8 +426,26 @@ def leave_Await( def visit_With_body(self, node: cst.With): if getattr(node, "asynchronous", None): self.uncheckpointed_statements = set() + if with_has_call(node, *cancel_scope_names) or with_has_call( + node, "timeout", "timeout_at", base="asyncio" + ): + pos = self.get_metadata(PositionProvider, node).start # pyright: ignore + line: int = pos.line # pyright: ignore + column: int = pos.column # pyright: ignore + self.uncheckpointed_statements.add(Statement("with", line, column)) + # self.uncheckpointed_statements.add(res[0]) + + def leave_With_body(self, node: cst.With): + pos = self.get_metadata(PositionProvider, node).start # pyright: ignore + line: int = pos.line # pyright: ignore + column: int = pos.column # pyright: ignore + s = Statement("with", line, column) + if s in self.uncheckpointed_statements: + self.error(node, error_code="ASYNC912") + self.uncheckpointed_statements.remove(s) - leave_With_body = visit_With_body + if getattr(node, "asynchronous", None): + self.uncheckpointed_statements = set() # error if no checkpoint since earlier yield or function entry def leave_Yield( diff --git a/tests/autofix_files/async91x_autofix.py b/tests/autofix_files/async91x_autofix.py index d7d99b3a..b572f833 100644 --- a/tests/autofix_files/async91x_autofix.py +++ b/tests/autofix_files/async91x_autofix.py @@ -9,6 +9,7 @@ # ARG --enable=ASYNC910,ASYNC911 from typing import Any + import trio @@ -124,3 +125,19 @@ async def async_func(): ... break [... for i in range(5)] return + + +# TODO: issue 240 +async def livelocks(): + while True: + ... + + +# this will autofix 910 by adding a checkpoint outside the loop +async def no_checkpoint(): # ASYNC910: 0, "exit", Statement("function definition", lineno) + while bar(): + try: + await trio.sleep("1") # type: ignore[arg-type] + except ValueError: + ... + await trio.lowlevel.checkpoint() diff --git a/tests/autofix_files/async91x_autofix.py.diff b/tests/autofix_files/async91x_autofix.py.diff index 2c84b107..5fa6c660 100644 --- a/tests/autofix_files/async91x_autofix.py.diff +++ b/tests/autofix_files/async91x_autofix.py.diff @@ -1,13 +1,5 @@ --- +++ -@@ x,6 x,7 @@ - # ARG --enable=ASYNC910,ASYNC911 - - from typing import Any -+import trio - - - def bar() -> Any: ... @@ x,30 x,38 @@ async def foo1(): # ASYNC910: 0, "exit", Statement("function definition", lineno) @@ -78,3 +70,8 @@ yield # ASYNC911: 8, "yield", Statement("function definition", lineno-2) # ASYNC911: 8, "yield", Statement("yield", lineno) async def bar(): +@@ x,3 x,4 @@ + await trio.sleep("1") # type: ignore[arg-type] + except ValueError: + ... ++ await trio.lowlevel.checkpoint() diff --git a/tests/eval_files/async912.py b/tests/eval_files/async912.py new file mode 100644 index 00000000..8b99aafb --- /dev/null +++ b/tests/eval_files/async912.py @@ -0,0 +1,108 @@ +# ASYNCIO_NO_ERROR +import trio + + +async def foo(): + with trio.move_on_after(0.1): # error: 4 + ... + with trio.move_on_at(0.1): # error: 4 + ... + with trio.fail_after(0.1): # error: 4 + ... + with trio.fail_at(0.1): # error: 4 + ... + with trio.CancelScope(0.1): # error: 4 + ... + + with open(""): + ... + + with trio.move_on_after(0.1): + await trio.lowlevel.checkpoint() + + with trio.move_on_after(0.1): # error: 4 + with trio.move_on_after(0.1): # error: 8 + ... + + with trio.move_on_after(0.1): # TODO: should probably raise an error? + with trio.move_on_after(0.1): + await trio.lowlevel.checkpoint() + + with trio.move_on_after(0.1): + await trio.lowlevel.checkpoint() + with trio.move_on_after(0.1): + await trio.lowlevel.checkpoint() + + with trio.move_on_after(0.1): + with trio.move_on_after(0.1): + await trio.lowlevel.checkpoint() + await trio.lowlevel.checkpoint() + + # TODO: should probably raise the error at the call, rather than at the with statement + # fmt: off + with ( # error: 4 + # a + # b + trio.move_on_after(0.1) + # c + ): + ... + + with ( # error: 4 + open(""), + trio.move_on_at(5), + open(""), + ): + ... + # fmt: on + + # TODO: only raises one error currently, can make it raise 2(?) + with ( # error: 4 + trio.move_on_after(0.1), + trio.fail_at(5), + ): + ... + + +# TODO: issue #240 +async def livelocks(): + with trio.move_on_after(0.1): # should error + while True: + try: + await trio.sleep("1") # type: ignore + except TypeError: + pass + + +def condition() -> bool: + return True + + +async def livelocks_2(): + with trio.move_on_after(0.1): # error: 4 + while condition(): + try: + await trio.sleep("1") # type: ignore + except TypeError: + pass + + +# TODO: add --async912-context-managers= +async def livelocks_3(): + import contextlib + + with trio.move_on_after(0.1): # should error + while True: + with contextlib.suppress(TypeError): + await trio.sleep("1") # type: ignore + + +# raises an error...? +with trio.move_on_after(10): # error: 0 + ... + + +# completely sync function ... is this something we care about? +def sync_func(): + with trio.move_on_after(10): + ... diff --git a/tests/eval_files/async912_asyncio.py b/tests/eval_files/async912_asyncio.py new file mode 100644 index 00000000..82b2547a --- /dev/null +++ b/tests/eval_files/async912_asyncio.py @@ -0,0 +1,19 @@ +# BASE_LIBRARY asyncio +# ANYIO_NO_ERROR +# TRIO_NO_ERROR + +# timeout[_at] added in py3.11 +# mypy: disable-error-code=attr-defined + +import asyncio + + +async def foo(): + async with asyncio.timeout(10): # error: 4 + ... + async with asyncio.timeout(10): + await foo() + async with asyncio.timeout_at(10): # error: 4 + ... + async with asyncio.timeout_at(10): + await foo() diff --git a/tests/eval_files/async91x_autofix.py b/tests/eval_files/async91x_autofix.py index de650311..58bb2a4c 100644 --- a/tests/eval_files/async91x_autofix.py +++ b/tests/eval_files/async91x_autofix.py @@ -10,6 +10,8 @@ from typing import Any +import trio + def bar() -> Any: ... @@ -109,3 +111,18 @@ async def async_func(): ... break [... for i in range(5)] return + + +# TODO: issue 240 +async def livelocks(): + while True: + ... + + +# this will autofix 910 by adding a checkpoint outside the loop +async def no_checkpoint(): # ASYNC910: 0, "exit", Statement("function definition", lineno) + while bar(): + try: + await trio.sleep("1") # type: ignore[arg-type] + except ValueError: + ... diff --git a/tests/test_flake8_async.py b/tests/test_flake8_async.py index a345a5ed..81a77c5a 100644 --- a/tests/test_flake8_async.py +++ b/tests/test_flake8_async.py @@ -452,6 +452,7 @@ def _parse_eval_file( "ASYNC116", "ASYNC117", "ASYNC118", + "ASYNC912", } @@ -479,7 +480,7 @@ def visit_AsyncFor(self, node: ast.AsyncFor): return self.replace_async(node, ast.For, node.target, node.iter) -@pytest.mark.parametrize(("test", "path"), test_files) +@pytest.mark.parametrize(("test", "path"), test_files, ids=[f[0] for f in test_files]) def test_noerror_on_sync_code(test: str, path: Path): if any(e in test for e in error_codes_ignored_when_checking_transformed_sync_code): return From 3114c7e98e1de225b2126526b999c2823c4b77af Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sat, 4 May 2024 16:54:40 +0200 Subject: [PATCH 2/5] shit is working --- CHANGELOG.md | 2 + README.md | 1 + flake8_async/__init__.py | 2 +- flake8_async/visitors/__init__.py | 1 - flake8_async/visitors/flake8asyncvisitor.py | 12 ++- flake8_async/visitors/helpers.py | 4 + flake8_async/visitors/visitor100.py | 90 ---------------- flake8_async/visitors/visitor91x.py | 108 +++++++++++++------ tests/autofix_files/async100.py | 7 +- tests/autofix_files/async100_asyncio.py | 25 +++++ tests/autofix_files/async100_asyncio.py.diff | 17 +++ tests/autofix_files/async910.py | 1 + tests/autofix_files/async911.py | 1 + tests/autofix_files/async91x_autofix.py | 2 + tests/autofix_files/noqa_testing.py | 3 + tests/eval_files/async100.py | 7 +- tests/eval_files/async100_asyncio.py | 12 ++- tests/eval_files/async910.py | 1 + tests/eval_files/async911.py | 1 + tests/eval_files/async912.py | 74 ++++++++++--- tests/eval_files/async912_asyncio.py | 22 +++- tests/eval_files/async91x_autofix.py | 2 + tests/eval_files/noqa_testing.py | 3 + tests/test_flake8_async.py | 49 ++++++--- 24 files changed, 280 insertions(+), 167 deletions(-) delete mode 100644 flake8_async/visitors/visitor100.py create mode 100644 tests/autofix_files/async100_asyncio.py create mode 100644 tests/autofix_files/async100_asyncio.py.diff diff --git a/CHANGELOG.md b/CHANGELOG.md index dd2d283b..0d4a66f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # Changelog *[CalVer, YY.month.patch](https://calver.org/)* +## 24.5.1 +- Add ASYNC912: no checkpoints in with statement are guaranteed to run. ## 24.4.1 - ASYNC91X fix internal error caused by multiple `try/except` incorrectly sharing state. diff --git a/README.md b/README.md index b98f4dc4..071b61fd 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ Note: 22X, 23X and 24X has not had asyncio-specific suggestions written. - **ASYNC910**: Exit or `return` from async function with no guaranteed checkpoint or exception since function definition. You might want to enable this on a codebase to make it easier to reason about checkpoints, and make the logic of ASYNC911 correct. - **ASYNC911**: Exit, `yield` or `return` from async iterable with no guaranteed checkpoint since possible function entry (yield or function definition) Checkpoints are `await`, `async for`, and `async with` (on one of enter/exit). +- **ASYNC912**: TODO: write ### Removed Warnings - **TRIOxxx**: All error codes are now renamed ASYNCxxx diff --git a/flake8_async/__init__.py b/flake8_async/__init__.py index e278c392..b808ba5d 100644 --- a/flake8_async/__init__.py +++ b/flake8_async/__init__.py @@ -37,7 +37,7 @@ # CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1" -__version__ = "24.4.1" +__version__ = "24.5.1" # taken from https://github.com/Zac-HD/shed diff --git a/flake8_async/visitors/__init__.py b/flake8_async/visitors/__init__.py index bd858eca..0b05011f 100644 --- a/flake8_async/visitors/__init__.py +++ b/flake8_async/visitors/__init__.py @@ -30,7 +30,6 @@ from . import ( visitor2xx, visitor91x, - visitor100, visitor101, visitor102, visitor103_104, diff --git a/flake8_async/visitors/flake8asyncvisitor.py b/flake8_async/visitors/flake8asyncvisitor.py index 160bedf9..92b4310c 100644 --- a/flake8_async/visitors/flake8asyncvisitor.py +++ b/flake8_async/visitors/flake8asyncvisitor.py @@ -98,7 +98,11 @@ def error( ), "No error code defined, but class has multiple codes" error_code = next(iter(self.error_codes)) # don't emit an error if this code is disabled in a multi-code visitor - elif strip_error_subidentifier(error_code) not in self.options.enabled_codes: + elif ( + (ec_no_sub := strip_error_subidentifier(error_code)) + not in self.options.enabled_codes + and ec_no_sub not in self.options.autofix_codes + ): return self.__state.problems.append( @@ -217,7 +221,11 @@ def error( error_code = next(iter(self.error_codes)) # don't emit an error if this code is disabled in a multi-code visitor # TODO: write test for only one of 910/911 enabled/autofixed - elif strip_error_subidentifier(error_code) not in self.options.enabled_codes: + elif ( + (ec_no_sub := strip_error_subidentifier(error_code)) + not in self.options.enabled_codes + and ec_no_sub not in self.options.autofix_codes + ): return False # pragma: no cover if self.is_noqa(node, error_code): diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py index 46f2b2e7..352ec144 100644 --- a/flake8_async/visitors/helpers.py +++ b/flake8_async/visitors/helpers.py @@ -58,6 +58,10 @@ def disabled_by_default(error_class: type[T_EITHER]) -> type[T_EITHER]: return error_class +def disable_codes_by_default(*codes: str) -> None: + default_disabled_error_codes.extend(codes) + + def utility_visitor(c: type[T]) -> type[T]: assert not hasattr(c, "error_codes") c.error_codes = {} diff --git a/flake8_async/visitors/visitor100.py b/flake8_async/visitors/visitor100.py deleted file mode 100644 index 345f8926..00000000 --- a/flake8_async/visitors/visitor100.py +++ /dev/null @@ -1,90 +0,0 @@ -"""Contains visitor for ASYNC100. - -A `with trio.fail_after(...):` or `with trio.move_on_after(...):` -context does not contain any `await` statements. This makes it pointless, as -the timeout can only be triggered by a checkpoint. -Checkpoints on Await, Async For and Async With -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -import libcst as cst -import libcst.matchers as m - -from .flake8asyncvisitor import Flake8AsyncVisitor_cst -from .helpers import ( - AttributeCall, - error_class_cst, - flatten_preserving_comments, - with_has_call, -) - -if TYPE_CHECKING: - from collections.abc import Mapping - - -@error_class_cst -class Visitor100_libcst(Flake8AsyncVisitor_cst): - error_codes: Mapping[str, str] = { - "ASYNC100": ( - "{0}.{1} context contains no checkpoints, remove the context or add" - " `await {0}.lowlevel.checkpoint()`." - ), - } - - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self.has_checkpoint_stack: list[bool] = [] - self.node_dict: dict[cst.With, list[AttributeCall]] = {} - - def checkpoint(self) -> None: - # Set the whole stack to True. - self.has_checkpoint_stack = [True] * len(self.has_checkpoint_stack) - - def visit_With(self, node: cst.With) -> None: - if m.matches(node, m.With(asynchronous=m.Asynchronous())): - self.checkpoint() - if res := with_has_call( - node, "fail_after", "fail_at", "move_on_after", "move_on_at", "CancelScope" - ): - self.node_dict[node] = res - - self.has_checkpoint_stack.append(False) - else: - self.has_checkpoint_stack.append(True) - - def leave_With( - self, original_node: cst.With, updated_node: cst.With - ) -> cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement]: - if not self.has_checkpoint_stack.pop(): - autofix = len(updated_node.items) == 1 - for res in self.node_dict[original_node]: - autofix &= self.error( - res.node, res.base, res.function - ) and self.should_autofix(res.node) - - if autofix: - return flatten_preserving_comments(updated_node) - - return updated_node - - def visit_For(self, node: cst.For): - if node.asynchronous is not None: - self.checkpoint() - - def visit_Await(self, node: cst.Await | cst.Yield): - self.checkpoint() - - visit_Yield = visit_Await - - def visit_FunctionDef(self, node: cst.FunctionDef): - self.save_state(node, "has_checkpoint_stack", copy=True) - self.has_checkpoint_stack = [] - - def leave_FunctionDef( - self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef - ) -> cst.FunctionDef: - self.restore_state(original_node) - return updated_node diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index 24c0b64c..16286ddd 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -18,9 +18,11 @@ from ..base import Statement from .flake8asyncvisitor import Flake8AsyncVisitor_cst from .helpers import ( + AttributeCall, cancel_scope_names, - disabled_by_default, + disable_codes_by_default, error_class_cst, + flatten_preserving_comments, fnmatch_qualified_name_cst, func_has_decorator, iter_guaranteed_once_cst, @@ -31,8 +33,11 @@ from collections.abc import Mapping, Sequence +class ArtificialStatement(Statement): ... + + # Statement injected at the start of loops to track missed checkpoints. -ARTIFICIAL_STATEMENT = Statement("artificial", -1) +ARTIFICIAL_STATEMENT = ArtificialStatement("artificial", -1) def func_empty_body(node: cst.FunctionDef) -> bool: @@ -233,8 +238,10 @@ def leave_Yield( leave_Return = leave_Yield # type: ignore +disable_codes_by_default("ASYNC910", "ASYNC911", "ASYNC912") + + @error_class_cst -@disabled_by_default class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors): error_codes: Mapping[str, str] = { "ASYNC910": ( @@ -249,6 +256,10 @@ class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors): "CancelScope with no guaranteed checkpoint. This makes it potentially " "impossible to cancel." ), + "ASYNC100": ( + "{0}.{1} context contains no checkpoints, remove the context or add" + " `await {0}.lowlevel.checkpoint()`." + ), } def __init__(self, *args: Any, **kwargs: Any): @@ -262,15 +273,24 @@ def __init__(self, *args: Any, **kwargs: Any): self.loop_state = LoopState() self.try_state = TryState() + # ASYNC100 + self.has_checkpoint_stack: list[bool] = [] + self.node_dict: dict[cst.With, list[AttributeCall]] = {} + def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool: + if code is None: + code = "ASYNC911" if self.has_yield else "ASYNC910" + return ( not self.noautofix - and super().should_autofix( - node, "ASYNC911" if self.has_yield else "ASYNC910" - ) + and super().should_autofix(node, code) and self.library != ("asyncio",) ) + def checkpoint(self) -> None: + self.uncheckpointed_statements = set() + self.has_checkpoint_stack = [True] * len(self.has_checkpoint_stack) + def checkpoint_statement(self) -> cst.SimpleStatementLine: return checkpoint_statement(self.library[0]) @@ -289,9 +309,11 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: "uncheckpointed_statements", "loop_state", "try_state", + "has_checkpoint_stack", copy=True, ) self.uncheckpointed_statements = set() + self.has_checkpoint_stack = [] self.has_yield = self.safe_decorator = False self.loop_state = LoopState() @@ -365,7 +387,7 @@ def check_function_exit( any_errors = False # raise the actual errors for statement in self.uncheckpointed_statements: - if statement == ARTIFICIAL_STATEMENT: + if isinstance(statement, ArtificialStatement): continue any_errors |= self.error_91x(original_node, statement) @@ -382,6 +404,7 @@ def leave_Return( self.add_statement = self.checkpoint_statement() # avoid duplicate error messages self.uncheckpointed_statements = set() + # we don't treat it as a checkpoint for ASYNC100 # return original node to avoid problems with identity equality assert original_node.deep_equals(updated_node) @@ -392,7 +415,7 @@ def error_91x( node: cst.Return | cst.FunctionDef | cst.Yield, statement: Statement, ) -> bool: - assert statement != ARTIFICIAL_STATEMENT + assert not isinstance(statement, ArtificialStatement) if isinstance(node, cst.FunctionDef): msg = "exit" @@ -413,7 +436,7 @@ def leave_Await( # so only set checkpoint after the await node # all nodes are now checkpointed - self.uncheckpointed_statements = set() + self.checkpoint() return updated_node # raising exception means we don't need to checkpoint so we can treat it as one @@ -425,27 +448,49 @@ def leave_Await( # missing-checkpoint warning when there might in fact be one (i.e. a false alarm). def visit_With_body(self, node: cst.With): if getattr(node, "asynchronous", None): - self.uncheckpointed_statements = set() - if with_has_call(node, *cancel_scope_names) or with_has_call( - node, "timeout", "timeout_at", base="asyncio" + self.checkpoint() + if res := ( + with_has_call(node, *cancel_scope_names) + or with_has_call(node, "timeout", "timeout_at", base="asyncio") ): pos = self.get_metadata(PositionProvider, node).start # pyright: ignore line: int = pos.line # pyright: ignore column: int = pos.column # pyright: ignore - self.uncheckpointed_statements.add(Statement("with", line, column)) - # self.uncheckpointed_statements.add(res[0]) - - def leave_With_body(self, node: cst.With): - pos = self.get_metadata(PositionProvider, node).start # pyright: ignore - line: int = pos.line # pyright: ignore - column: int = pos.column # pyright: ignore - s = Statement("with", line, column) - if s in self.uncheckpointed_statements: - self.error(node, error_code="ASYNC912") - self.uncheckpointed_statements.remove(s) - - if getattr(node, "asynchronous", None): - self.uncheckpointed_statements = set() + self.uncheckpointed_statements.add( + ArtificialStatement("with", line, column) + ) + self.node_dict[node] = res + self.has_checkpoint_stack.append(False) + else: + self.has_checkpoint_stack.append(True) + + def leave_With(self, original_node: cst.With, updated_node: cst.With): + # ASYNC100 + if not self.has_checkpoint_stack.pop(): + autofix = len(updated_node.items) == 1 + for res in self.node_dict[original_node]: + # bypass 910 & 911's should_autofix logic, which excludes asyncio + # (TODO: and uses self.noautofix ... which I don't remember what it's for) + autofix &= self.error( + res.node, res.base, res.function, error_code="ASYNC100" + ) and super().should_autofix(res.node, code="ASYNC100") + + if autofix: + return flatten_preserving_comments(updated_node) + # ASYNC912 + else: + pos = self.get_metadata( # pyright: ignore + PositionProvider, original_node + ).start # pyright: ignore + line: int = pos.line # pyright: ignore + column: int = pos.column # pyright: ignore + s = ArtificialStatement("with", line, column) + if s in self.uncheckpointed_statements: + self.error(original_node, error_code="ASYNC912") + self.uncheckpointed_statements.remove(s) + if getattr(original_node, "asynchronous", None): + self.checkpoint() + return updated_node # error if no checkpoint since earlier yield or function entry def leave_Yield( @@ -455,6 +500,9 @@ def leave_Yield( return updated_node self.has_yield = True + # Treat as a checkpoint for ASYNC100 + self.has_checkpoint_stack = [True] * len(self.has_checkpoint_stack) + if self.check_function_exit(original_node) and self.should_autofix( original_node ): @@ -629,7 +677,7 @@ def visit_While_body(self, node: cst.For | cst.While): # appropriate errors if the loop doesn't checkpoint if getattr(node, "asynchronous", None): - self.uncheckpointed_statements = set() + self.checkpoint() else: self.uncheckpointed_statements = {ARTIFICIAL_STATEMENT} @@ -675,7 +723,7 @@ def leave_While_body(self, node: cst.For | cst.While): # AsyncFor guarantees checkpoint on running out of iterable # so reset checkpoint state at end of loop. (but not state at break) if getattr(node, "asynchronous", None): - self.uncheckpointed_statements = set() + self.checkpoint() else: # enter orelse with worst case: # loop body might execute fully before entering orelse @@ -699,7 +747,7 @@ def leave_While_orelse(self, node: cst.For | cst.While): # if this is an infinite loop, with no break in it, don't raise # alarms about the state after it. if self.loop_state.infinite_loop and not self.loop_state.has_break: - self.uncheckpointed_statements = set() + self.checkpoint() else: # We may exit from: # orelse (covering: no body, body until continue, and all body) @@ -804,7 +852,7 @@ def visit_CompFor(self, node: cst.CompFor): # if async comprehension, checkpoint if node.asynchronous: - self.uncheckpointed_statements = set() + self.checkpoint() self.comp_unknown = False return False diff --git a/tests/autofix_files/async100.py b/tests/autofix_files/async100.py index db6733b5..bd3d2809 100644 --- a/tests/autofix_files/async100.py +++ b/tests/autofix_files/async100.py @@ -71,10 +71,11 @@ async def foo(): ... -# Seems like the inner context manager 'hides' the checkpoint. +# The outer cancelscope can get triggered in more complex cases, so +# to avoid false positives we don't raise a warning. async def does_contain_checkpoints(): - with trio.fail_after(1): # false-alarm ASYNC100 - with trio.CancelScope(): # or any other context manager + with trio.fail_after(1): + with trio.CancelScope(): await trio.sleep_forever() diff --git a/tests/autofix_files/async100_asyncio.py b/tests/autofix_files/async100_asyncio.py new file mode 100644 index 00000000..c3e53a8b --- /dev/null +++ b/tests/autofix_files/async100_asyncio.py @@ -0,0 +1,25 @@ +# TRIO_NO_ERROR +# ANYIO_NO_ERROR +# BASE_LIBRARY asyncio + +# mypy: disable-error-code=attr-defined +# AUTOFIX + +import asyncio +import asyncio.timeouts + + +async def foo(): + # py>=3.11 re-exports these in the main asyncio namespace + # error: 9, "asyncio", "timeout_at" + ... + # error: 9, "asyncio", "timeout" + ... + + # TODO + with asyncio.timeouts.timeout_at(10): + ... + with asyncio.timeouts.timeout_at(10): + ... + with asyncio.timeouts.timeout(10): + ... diff --git a/tests/autofix_files/async100_asyncio.py.diff b/tests/autofix_files/async100_asyncio.py.diff new file mode 100644 index 00000000..58e39f5e --- /dev/null +++ b/tests/autofix_files/async100_asyncio.py.diff @@ -0,0 +1,17 @@ +--- ++++ +@@ x,10 x,10 @@ + + async def foo(): + # py>=3.11 re-exports these in the main asyncio namespace +- with asyncio.timeout_at(10): # error: 9, "asyncio", "timeout_at" +- ... +- with asyncio.timeout(10): # error: 9, "asyncio", "timeout" +- ... ++ # error: 9, "asyncio", "timeout_at" ++ ... ++ # error: 9, "asyncio", "timeout" ++ ... + + + # TODO diff --git a/tests/autofix_files/async910.py b/tests/autofix_files/async910.py index 0415a7a4..95922ab7 100644 --- a/tests/autofix_files/async910.py +++ b/tests/autofix_files/async910.py @@ -1,4 +1,5 @@ # AUTOFIX +# ASYNCIO_NO_AUTOFIX # mypy: disable-error-code="unreachable" from __future__ import annotations diff --git a/tests/autofix_files/async911.py b/tests/autofix_files/async911.py index 720a9811..a91a322a 100644 --- a/tests/autofix_files/async911.py +++ b/tests/autofix_files/async911.py @@ -1,4 +1,5 @@ # AUTOFIX +# ASYNCIO_NO_AUTOFIX from typing import Any import pytest diff --git a/tests/autofix_files/async91x_autofix.py b/tests/autofix_files/async91x_autofix.py index b572f833..769e0e55 100644 --- a/tests/autofix_files/async91x_autofix.py +++ b/tests/autofix_files/async91x_autofix.py @@ -1,4 +1,6 @@ # AUTOFIX +# asyncio will raise the same errors, but does not have autofix available +# ASYNCIO_NO_AUTOFIX from __future__ import annotations """Docstring for file diff --git a/tests/autofix_files/noqa_testing.py b/tests/autofix_files/noqa_testing.py index 9bb4e456..b55942c4 100644 --- a/tests/autofix_files/noqa_testing.py +++ b/tests/autofix_files/noqa_testing.py @@ -1,4 +1,7 @@ +# TODO: When was this file added? Why? + # AUTOFIX +# ASYNCIO_NO_AUTOFIX # ARG --enable=ASYNC911 import trio diff --git a/tests/eval_files/async100.py b/tests/eval_files/async100.py index c9a00b53..226e34d1 100644 --- a/tests/eval_files/async100.py +++ b/tests/eval_files/async100.py @@ -71,10 +71,11 @@ async def foo(): ... -# Seems like the inner context manager 'hides' the checkpoint. +# The outer cancelscope can get triggered in more complex cases, so +# to avoid false positives we don't raise a warning. async def does_contain_checkpoints(): - with trio.fail_after(1): # false-alarm ASYNC100 - with trio.CancelScope(): # or any other context manager + with trio.fail_after(1): + with trio.CancelScope(): await trio.sleep_forever() diff --git a/tests/eval_files/async100_asyncio.py b/tests/eval_files/async100_asyncio.py index 9dd743c8..c853379d 100644 --- a/tests/eval_files/async100_asyncio.py +++ b/tests/eval_files/async100_asyncio.py @@ -1,7 +1,9 @@ # TRIO_NO_ERROR # ANYIO_NO_ERROR # BASE_LIBRARY asyncio -# ASYNCIO_NO_ERROR # TODO + +# mypy: disable-error-code=attr-defined +# AUTOFIX import asyncio import asyncio.timeouts @@ -9,12 +11,12 @@ async def foo(): # py>=3.11 re-exports these in the main asyncio namespace - with asyncio.timeout_at(10): # type: ignore[attr-defined] - ... - with asyncio.timeout_at(10): # type: ignore[attr-defined] + with asyncio.timeout_at(10): # error: 9, "asyncio", "timeout_at" ... - with asyncio.timeout(10): # type: ignore[attr-defined] + with asyncio.timeout(10): # error: 9, "asyncio", "timeout" ... + + # TODO with asyncio.timeouts.timeout_at(10): ... with asyncio.timeouts.timeout_at(10): diff --git a/tests/eval_files/async910.py b/tests/eval_files/async910.py index 11600888..68aee89f 100644 --- a/tests/eval_files/async910.py +++ b/tests/eval_files/async910.py @@ -1,4 +1,5 @@ # AUTOFIX +# ASYNCIO_NO_AUTOFIX # mypy: disable-error-code="unreachable" from __future__ import annotations diff --git a/tests/eval_files/async911.py b/tests/eval_files/async911.py index 8a19c525..b6d256de 100644 --- a/tests/eval_files/async911.py +++ b/tests/eval_files/async911.py @@ -1,4 +1,5 @@ # AUTOFIX +# ASYNCIO_NO_AUTOFIX from typing import Any import pytest diff --git a/tests/eval_files/async912.py b/tests/eval_files/async912.py index 8b99aafb..89361e12 100644 --- a/tests/eval_files/async912.py +++ b/tests/eval_files/async912.py @@ -1,30 +1,67 @@ # ASYNCIO_NO_ERROR +# ARG --enable=ASYNC100,ASYNC912 +# asyncio is tested in async912_asyncio. Cancelscopes in anyio are named the same +# as in trio, so they're also tested with this file. + +# ASYNC100 has autofixes, but ASYNC912 does not. This leaves us with the option +# of not testing both in the same file, or running with NOAUTOFIX. +# NOAUTOFIX + import trio +def bar() -> bool: + return False + + async def foo(): - with trio.move_on_after(0.1): # error: 4 + # trivial cases where there is absolutely no `await` only triggers ASYNC100 + with trio.move_on_after(0.1): # ASYNC100: 9, "trio", "move_on_after" ... - with trio.move_on_at(0.1): # error: 4 + with trio.move_on_at(0.1): # ASYNC100: 9, "trio", "move_on_at" ... - with trio.fail_after(0.1): # error: 4 + with trio.fail_after(0.1): # ASYNC100: 9, "trio", "fail_after" ... - with trio.fail_at(0.1): # error: 4 + with trio.fail_at(0.1): # ASYNC100: 9, "trio", "fail_at" ... - with trio.CancelScope(0.1): # error: 4 + with trio.CancelScope(0.1): # ASYNC100: 9, "trio", "CancelScope" ... + with trio.move_on_after(0.1): # ASYNC912: 4 + if bar(): + await trio.lowlevel.checkpoint() + with trio.move_on_at(0.1): # ASYNC912: 4 + while bar(): + await trio.lowlevel.checkpoint() + with trio.fail_after(0.1): # ASYNC912: 4 + try: + await trio.lowlevel.checkpoint() + except: + ... + with trio.fail_at(0.1): # ASYNC912: 4 + if bar(): + await trio.lowlevel.checkpoint() + with trio.CancelScope(0.1): # ASYNC912: 4 + if bar(): + await trio.lowlevel.checkpoint() + # ASYNC912 generally shares the same logic as other 91x codes, check respective + # eval files for more comprehensive tests. + + # check we don't trigger on all context managers with open(""): ... with trio.move_on_after(0.1): await trio.lowlevel.checkpoint() - with trio.move_on_after(0.1): # error: 4 - with trio.move_on_after(0.1): # error: 8 - ... + with trio.move_on_after(0.1): # ASYNC912: 4 + with trio.move_on_after(0.1): # ASYNC912: 8 + if bar(): + await trio.lowlevel.checkpoint() - with trio.move_on_after(0.1): # TODO: should probably raise an error? + # We don't know which cancelscope will trigger first, so to avoid false + # positives on tricky-but-valid cases we don't raise any error for the outer one. + with trio.move_on_after(0.1): with trio.move_on_after(0.1): await trio.lowlevel.checkpoint() @@ -40,28 +77,31 @@ async def foo(): # TODO: should probably raise the error at the call, rather than at the with statement # fmt: off - with ( # error: 4 + with ( # ASYNC912: 4 # a # b trio.move_on_after(0.1) # c ): - ... + if bar(): + await trio.lowlevel.checkpoint() - with ( # error: 4 + with ( # ASYNC912: 4 open(""), trio.move_on_at(5), open(""), ): - ... + if bar(): + await trio.lowlevel.checkpoint() # fmt: on # TODO: only raises one error currently, can make it raise 2(?) - with ( # error: 4 + with ( # ASYNC912: 4 trio.move_on_after(0.1), trio.fail_at(5), ): - ... + if bar(): + await trio.lowlevel.checkpoint() # TODO: issue #240 @@ -79,7 +119,7 @@ def condition() -> bool: async def livelocks_2(): - with trio.move_on_after(0.1): # error: 4 + with trio.move_on_after(0.1): # ASYNC912: 4 while condition(): try: await trio.sleep("1") # type: ignore @@ -98,7 +138,7 @@ async def livelocks_3(): # raises an error...? -with trio.move_on_after(10): # error: 0 +with trio.move_on_after(10): # ASYNC100: 5, "trio", "move_on_after" ... diff --git a/tests/eval_files/async912_asyncio.py b/tests/eval_files/async912_asyncio.py index 82b2547a..d41f1562 100644 --- a/tests/eval_files/async912_asyncio.py +++ b/tests/eval_files/async912_asyncio.py @@ -1,19 +1,35 @@ +# ARG --enable=ASYNC100,ASYNC912 # BASE_LIBRARY asyncio # ANYIO_NO_ERROR # TRIO_NO_ERROR +# ASYNC100 supports autofix, but ASYNC912 doesn't, so we must run with NOAUTOFIX +# NOAUTOFIX + # timeout[_at] added in py3.11 # mypy: disable-error-code=attr-defined import asyncio +def bar() -> bool: + return False + + async def foo(): - async with asyncio.timeout(10): # error: 4 + async with asyncio.timeout(10): # ASYNC100: 15, "asyncio", "timeout" + ... + async with asyncio.timeout_at(10): # ASYNC100: 15, "asyncio", "timeout_at" ... + async with asyncio.timeout(10): await foo() - async with asyncio.timeout_at(10): # error: 4 - ... async with asyncio.timeout_at(10): await foo() + + async with asyncio.timeout_at(10): # ASYNC912: 4 + if bar(): + await foo() + async with asyncio.timeout(10): # ASYNC912: 4 + if bar(): + await foo() diff --git a/tests/eval_files/async91x_autofix.py b/tests/eval_files/async91x_autofix.py index 58bb2a4c..9aa6a0d1 100644 --- a/tests/eval_files/async91x_autofix.py +++ b/tests/eval_files/async91x_autofix.py @@ -1,4 +1,6 @@ # AUTOFIX +# asyncio will raise the same errors, but does not have autofix available +# ASYNCIO_NO_AUTOFIX from __future__ import annotations """Docstring for file diff --git a/tests/eval_files/noqa_testing.py b/tests/eval_files/noqa_testing.py index 1c6ea8f5..1a1a3440 100644 --- a/tests/eval_files/noqa_testing.py +++ b/tests/eval_files/noqa_testing.py @@ -1,4 +1,7 @@ +# TODO: When was this file added? Why? + # AUTOFIX +# ASYNCIO_NO_AUTOFIX # ARG --enable=ASYNC911 import trio diff --git a/tests/test_flake8_async.py b/tests/test_flake8_async.py index 81a77c5a..3fa78582 100644 --- a/tests/test_flake8_async.py +++ b/tests/test_flake8_async.py @@ -109,22 +109,35 @@ def check_autofix( plugin: Plugin, unfixed_code: str, generate_autofix: bool, + magic_markers: MagicMarkers, library: str = "trio", - base_library: str = "trio", ): + base_library = magic_markers.BASE_LIBRARY # the source code after it's been visited by current transformers visited_code = plugin.module.code - if "# AUTOFIX" not in unfixed_code: - # if the file is specifically marked with NOAUTOFIX, that means it has visitors - # that will autofix with --autofix, but the file explicitly doesn't want to check - # the result of doing that. THIS IS DANGEROUS - if "# NOAUTOFIX" in unfixed_code: - print(f"eval file {test} marked with dangerous marker NOAUTOFIX") - else: - assert unfixed_code == visited_code + # if the file is specifically marked with NOAUTOFIX, that means it has visitors + # that will autofix with --autofix, but the file explicitly doesn't want to check + # the result of doing that. THIS IS DANGEROUS + assert not (magic_markers.AUTOFIX and magic_markers.NOAUTOFIX) + if magic_markers.NOAUTOFIX: + assert "# AUTOFIX" not in unfixed_code + print(f"eval file {test} marked with dangerous marker NOAUTOFIX") + return + + if ( + # not marked for autofixing + not magic_markers.AUTOFIX + # file+library does not raise errors + or magic_markers.library_no_error(library) + # code raises errors on asyncio, but does not support autofixing for it + or (library == "asyncio" and magic_markers.ASYNCIO_NO_AUTOFIX) + ): + assert unfixed_code == visited_code return + # if AUTOFIX, and library_NO_ERROR, assert file content isn't changed + # the full generated source code, saved from a previous run if test not in autofix_files: autofix_files[test] = AUTOFIX_DIR / (test.lower() + ".py") @@ -196,9 +209,22 @@ class MagicMarkers: ANYIO_NO_ERROR: bool = False TRIO_NO_ERROR: bool = False ASYNCIO_NO_ERROR: bool = False + + AUTOFIX: bool = False + NOAUTOFIX: bool = False + + # File should not get modified when running with asyncio+autofix + ASYNCIO_NO_AUTOFIX: bool = False # eval file is written using this library, so no substitution is required BASE_LIBRARY: str = "trio" + def library_no_error(self, library: str) -> bool: + return { + "anyio": self.ANYIO_NO_ERROR, + "asyncio": self.ASYNCIO_NO_ERROR, + "trio": self.TRIO_NO_ERROR, + }[library] + def find_magic_markers( content: str, @@ -300,15 +326,14 @@ def test_eval( lib in message for lib in ("anyio", "asyncio", "trio") ) - # asyncio does not support autofix atm, so should not modify content - if autofix and not noqa and library != "asyncio": + if autofix and not noqa: check_autofix( test, plugin, content, generate_autofix, library=library, - base_library=magic_markers.BASE_LIBRARY, + magic_markers=magic_markers, ) else: # make sure content isn't modified From 231ba7445087ee82dd4a38438ebe845b25f74669 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 5 May 2024 14:27:10 +0200 Subject: [PATCH 3/5] cleanup, documentation, and last few TODOs fixed. --- CHANGELOG.md | 2 + README.md | 2 +- docs/rules.rst | 1 + flake8_async/visitors/flake8asyncvisitor.py | 2 +- flake8_async/visitors/helpers.py | 35 +++++++++++++--- flake8_async/visitors/visitor91x.py | 24 ++++++++--- tests/autofix_files/async100_asyncio.py | 13 +++--- tests/autofix_files/async100_asyncio.py.diff | 14 +++++-- tests/autofix_files/async91x_autofix.py | 9 ++-- tests/autofix_files/async91x_autofix.py.diff | 12 +++++- tests/eval_files/async100_asyncio.py | 9 ++-- tests/eval_files/async912.py | 43 ++++++++++++-------- tests/eval_files/async912_asyncio.py | 24 +++++++++-- tests/eval_files/async91x_autofix.py | 10 ++--- tests/test_flake8_async.py | 7 ++-- 15 files changed, 141 insertions(+), 66 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d4a66f2..3d3b9f06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,8 @@ ## 24.5.1 - Add ASYNC912: no checkpoints in with statement are guaranteed to run. +- ASYNC100 now properly treats async for comprehensions as checkpoints. +- ASYNC100 now supports autofixing on asyncio. ## 24.4.1 - ASYNC91X fix internal error caused by multiple `try/except` incorrectly sharing state. diff --git a/README.md b/README.md index 071b61fd..98ed50ee 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ Note: 22X, 23X and 24X has not had asyncio-specific suggestions written. - **ASYNC910**: Exit or `return` from async function with no guaranteed checkpoint or exception since function definition. You might want to enable this on a codebase to make it easier to reason about checkpoints, and make the logic of ASYNC911 correct. - **ASYNC911**: Exit, `yield` or `return` from async iterable with no guaranteed checkpoint since possible function entry (yield or function definition) Checkpoints are `await`, `async for`, and `async with` (on one of enter/exit). -- **ASYNC912**: TODO: write +- **ASYNC912**: Timeout/Cancelscope has no awaits that are guaranteed to run. If the scope has no checkpoints at all, then `ASYNC100` will be raised instead. ### Removed Warnings - **TRIOxxx**: All error codes are now renamed ASYNCxxx diff --git a/docs/rules.rst b/docs/rules.rst index 007a5485..729a2fea 100644 --- a/docs/rules.rst +++ b/docs/rules.rst @@ -49,6 +49,7 @@ Optional rules disabled by default - **ASYNC910**: Exit or ``return`` from async function with no guaranteed checkpoint or exception since function definition. You might want to enable this on a codebase to make it easier to reason about checkpoints, and make the logic of ASYNC911 correct. - **ASYNC911**: Exit, ``yield`` or ``return`` from async iterable with no guaranteed checkpoint since possible function entry (yield or function definition) Checkpoints are ``await``, ``async for``, and ``async with`` (on one of enter/exit). +-- **ASYNC912**: A timeout/cancelscope has checkpoints, but they're not guaranteed to run. Similar to ASYNC100, but it does not warn on trivial cases where there is no checkpoint at all. It instead shares logic with ASYNC910 and ASYNC911 for parsing conditionals and branches. Removed rules ================ diff --git a/flake8_async/visitors/flake8asyncvisitor.py b/flake8_async/visitors/flake8asyncvisitor.py index 92b4310c..7bce4a9a 100644 --- a/flake8_async/visitors/flake8asyncvisitor.py +++ b/flake8_async/visitors/flake8asyncvisitor.py @@ -245,7 +245,7 @@ def error( return True def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool: - if code is None: + if code is None: # pragma: no cover assert len(self.error_codes) == 1 code = next(iter(self.error_codes)) # this does not currently need to check for `noqa`s, as error() does that diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py index 352ec144..7b3147e3 100644 --- a/flake8_async/visitors/helpers.py +++ b/flake8_async/visitors/helpers.py @@ -51,6 +51,7 @@ def error_class_cst(error_class: type[T_CST]) -> type[T_CST]: def disabled_by_default(error_class: type[T_EITHER]) -> type[T_EITHER]: + """Default-disables all error codes in a class.""" assert error_class.error_codes # type: ignore[attr-defined] default_disabled_error_codes.extend( error_class.error_codes # type: ignore[attr-defined] @@ -59,6 +60,7 @@ def disabled_by_default(error_class: type[T_EITHER]) -> type[T_EITHER]: def disable_codes_by_default(*codes: str) -> None: + """Default-disables only specified codes.""" default_disabled_error_codes.extend(codes) @@ -325,14 +327,19 @@ def with_has_call( node: cst.With, *names: str, base: Iterable[str] = ("trio", "anyio") ) -> list[AttributeCall]: if isinstance(base, str): - base = (base,) + base = (base,) # pragma: no cover + + for b in base: + if b.count(".") > 1: # pragma: no cover + raise NotImplementedError("Does not support 3-module bases atm.") + res_list: list[AttributeCall] = [] for item in node.items: if res := m.extract( item.item, m.Call( func=m.Attribute( - value=m.SaveMatchedNode(m.Name(), name="library"), + value=m.SaveMatchedNode(m.Name() | m.Attribute(), name="library"), attr=m.SaveMatchedNode( oneof_names(*names), name="function", @@ -341,12 +348,30 @@ def with_has_call( ), ): assert isinstance(item.item, cst.Call) - assert isinstance(res["library"], cst.Name) + assert isinstance(res["library"], (cst.Name, cst.Attribute)) assert isinstance(res["function"], cst.Name) - if res["library"].value not in base: + library_node = res["library"] + for library_str in base: + if ( + isinstance(library_node, cst.Name) + and library_str == library_node.value + ): + break + if ( + isinstance(library_node, cst.Attribute) + and isinstance(library_node.value, cst.Name) + and "." in library_str + ): + base_1, base_2 = library_str.split(".") + if ( + library_node.attr.value == base_2 + and library_node.value.value == base_1 + ): + break + else: continue res_list.append( - AttributeCall(item.item, res["library"].value, res["function"].value) + AttributeCall(item.item, library_str, res["function"].value) ) return res_list diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index 16286ddd..bbfd696b 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -33,11 +33,17 @@ from collections.abc import Mapping, Sequence -class ArtificialStatement(Statement): ... +class ArtificialStatement(Statement): + """Statement that should not trigger 910/911 on function exit. + + Used by loops and `with` statements. + """ # Statement injected at the start of loops to track missed checkpoints. ARTIFICIAL_STATEMENT = ArtificialStatement("artificial", -1) +# There's no particular reason why loops use a globally instanced statement, but +# `with` does not - mostly just an artifact of them being implemented at different times. def func_empty_body(node: cst.FunctionDef) -> bool: @@ -278,7 +284,7 @@ def __init__(self, *args: Any, **kwargs: Any): self.node_dict: dict[cst.With, list[AttributeCall]] = {} def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool: - if code is None: + if code is None: # pragma: no branch code = "ASYNC911" if self.has_yield else "ASYNC910" return ( @@ -451,7 +457,9 @@ def visit_With_body(self, node: cst.With): self.checkpoint() if res := ( with_has_call(node, *cancel_scope_names) - or with_has_call(node, "timeout", "timeout_at", base="asyncio") + or with_has_call( + node, "timeout", "timeout_at", base=("asyncio", "asyncio.timeouts") + ) ): pos = self.get_metadata(PositionProvider, node).start # pyright: ignore line: int = pos.line # pyright: ignore @@ -465,6 +473,8 @@ def visit_With_body(self, node: cst.With): self.has_checkpoint_stack.append(True) def leave_With(self, original_node: cst.With, updated_node: cst.With): + # Uses leave_With instead of leave_With_body because we need access to both + # original and updated node # ASYNC100 if not self.has_checkpoint_stack.pop(): autofix = len(updated_node.items) == 1 @@ -486,8 +496,9 @@ def leave_With(self, original_node: cst.With, updated_node: cst.With): column: int = pos.column # pyright: ignore s = ArtificialStatement("with", line, column) if s in self.uncheckpointed_statements: - self.error(original_node, error_code="ASYNC912") self.uncheckpointed_statements.remove(s) + for res in self.node_dict[original_node]: + self.error(res.node, error_code="ASYNC912") if getattr(original_node, "asynchronous", None): self.checkpoint() return updated_node @@ -500,7 +511,8 @@ def leave_Yield( return updated_node self.has_yield = True - # Treat as a checkpoint for ASYNC100 + # Treat as a checkpoint for ASYNC100, since the context we yield to + # may checkpoint. self.has_checkpoint_stack = [True] * len(self.has_checkpoint_stack) if self.check_function_exit(original_node) and self.should_autofix( @@ -747,7 +759,7 @@ def leave_While_orelse(self, node: cst.For | cst.While): # if this is an infinite loop, with no break in it, don't raise # alarms about the state after it. if self.loop_state.infinite_loop and not self.loop_state.has_break: - self.checkpoint() + self.uncheckpointed_statements = set() else: # We may exit from: # orelse (covering: no body, body until continue, and all body) diff --git a/tests/autofix_files/async100_asyncio.py b/tests/autofix_files/async100_asyncio.py index c3e53a8b..cfd0121e 100644 --- a/tests/autofix_files/async100_asyncio.py +++ b/tests/autofix_files/async100_asyncio.py @@ -2,6 +2,7 @@ # ANYIO_NO_ERROR # BASE_LIBRARY asyncio +# timeout[_at] re-exported in the main asyncio namespace in py3.11 # mypy: disable-error-code=attr-defined # AUTOFIX @@ -10,16 +11,12 @@ async def foo(): - # py>=3.11 re-exports these in the main asyncio namespace # error: 9, "asyncio", "timeout_at" ... # error: 9, "asyncio", "timeout" ... - # TODO - with asyncio.timeouts.timeout_at(10): - ... - with asyncio.timeouts.timeout_at(10): - ... - with asyncio.timeouts.timeout(10): - ... + # error: 9, "asyncio.timeouts", "timeout_at" + ... + # error: 9, "asyncio.timeouts", "timeout" + ... diff --git a/tests/autofix_files/async100_asyncio.py.diff b/tests/autofix_files/async100_asyncio.py.diff index 58e39f5e..f083238a 100644 --- a/tests/autofix_files/async100_asyncio.py.diff +++ b/tests/autofix_files/async100_asyncio.py.diff @@ -1,9 +1,9 @@ --- +++ -@@ x,10 x,10 @@ +@@ x,12 x,12 @@ + async def foo(): - # py>=3.11 re-exports these in the main asyncio namespace - with asyncio.timeout_at(10): # error: 9, "asyncio", "timeout_at" - ... - with asyncio.timeout(10): # error: 9, "asyncio", "timeout" @@ -13,5 +13,11 @@ + # error: 9, "asyncio", "timeout" + ... - - # TODO +- with asyncio.timeouts.timeout_at(10): # error: 9, "asyncio.timeouts", "timeout_at" +- ... +- with asyncio.timeouts.timeout(10): # error: 9, "asyncio.timeouts", "timeout" +- ... ++ # error: 9, "asyncio.timeouts", "timeout_at" ++ ... ++ # error: 9, "asyncio.timeouts", "timeout" ++ ... diff --git a/tests/autofix_files/async91x_autofix.py b/tests/autofix_files/async91x_autofix.py index 769e0e55..35fe6ff2 100644 --- a/tests/autofix_files/async91x_autofix.py +++ b/tests/autofix_files/async91x_autofix.py @@ -11,7 +11,6 @@ # ARG --enable=ASYNC910,ASYNC911 from typing import Any - import trio @@ -135,11 +134,13 @@ async def livelocks(): ... -# this will autofix 910 by adding a checkpoint outside the loop +# this will autofix 910 by adding a checkpoint outside the loop, which doesn't actually +# help, and the method still isn't guaranteed to checkpoint in case bar() always returns +# True. async def no_checkpoint(): # ASYNC910: 0, "exit", Statement("function definition", lineno) while bar(): try: - await trio.sleep("1") # type: ignore[arg-type] - except ValueError: + await foo("1") # type: ignore[call-arg] + except TypeError: ... await trio.lowlevel.checkpoint() diff --git a/tests/autofix_files/async91x_autofix.py.diff b/tests/autofix_files/async91x_autofix.py.diff index 5fa6c660..e6e6625d 100644 --- a/tests/autofix_files/async91x_autofix.py.diff +++ b/tests/autofix_files/async91x_autofix.py.diff @@ -1,5 +1,13 @@ --- +++ +@@ x,6 x,7 @@ + # ARG --enable=ASYNC910,ASYNC911 + + from typing import Any ++import trio + + + def bar() -> Any: ... @@ x,30 x,38 @@ async def foo1(): # ASYNC910: 0, "exit", Statement("function definition", lineno) @@ -71,7 +79,7 @@ async def bar(): @@ x,3 x,4 @@ - await trio.sleep("1") # type: ignore[arg-type] - except ValueError: + await foo("1") # type: ignore[call-arg] + except TypeError: ... + await trio.lowlevel.checkpoint() diff --git a/tests/eval_files/async100_asyncio.py b/tests/eval_files/async100_asyncio.py index c853379d..494803ab 100644 --- a/tests/eval_files/async100_asyncio.py +++ b/tests/eval_files/async100_asyncio.py @@ -2,6 +2,7 @@ # ANYIO_NO_ERROR # BASE_LIBRARY asyncio +# timeout[_at] re-exported in the main asyncio namespace in py3.11 # mypy: disable-error-code=attr-defined # AUTOFIX @@ -10,16 +11,12 @@ async def foo(): - # py>=3.11 re-exports these in the main asyncio namespace with asyncio.timeout_at(10): # error: 9, "asyncio", "timeout_at" ... with asyncio.timeout(10): # error: 9, "asyncio", "timeout" ... - # TODO - with asyncio.timeouts.timeout_at(10): + with asyncio.timeouts.timeout_at(10): # error: 9, "asyncio.timeouts", "timeout_at" ... - with asyncio.timeouts.timeout_at(10): - ... - with asyncio.timeouts.timeout(10): + with asyncio.timeouts.timeout(10): # error: 9, "asyncio.timeouts", "timeout" ... diff --git a/tests/eval_files/async912.py b/tests/eval_files/async912.py index 89361e12..3ebbc678 100644 --- a/tests/eval_files/async912.py +++ b/tests/eval_files/async912.py @@ -27,21 +27,21 @@ async def foo(): with trio.CancelScope(0.1): # ASYNC100: 9, "trio", "CancelScope" ... - with trio.move_on_after(0.1): # ASYNC912: 4 + with trio.move_on_after(0.1): # ASYNC912: 9 if bar(): await trio.lowlevel.checkpoint() - with trio.move_on_at(0.1): # ASYNC912: 4 + with trio.move_on_at(0.1): # ASYNC912: 9 while bar(): await trio.lowlevel.checkpoint() - with trio.fail_after(0.1): # ASYNC912: 4 + with trio.fail_after(0.1): # ASYNC912: 9 try: await trio.lowlevel.checkpoint() except: ... - with trio.fail_at(0.1): # ASYNC912: 4 + with trio.fail_at(0.1): # ASYNC912: 9 if bar(): await trio.lowlevel.checkpoint() - with trio.CancelScope(0.1): # ASYNC912: 4 + with trio.CancelScope(0.1): # ASYNC912: 9 if bar(): await trio.lowlevel.checkpoint() # ASYNC912 generally shares the same logic as other 91x codes, check respective @@ -54,8 +54,8 @@ async def foo(): with trio.move_on_after(0.1): await trio.lowlevel.checkpoint() - with trio.move_on_after(0.1): # ASYNC912: 4 - with trio.move_on_after(0.1): # ASYNC912: 8 + with trio.move_on_after(0.1): # ASYNC912: 9 + with trio.move_on_after(0.1): # ASYNC912: 13 if bar(): await trio.lowlevel.checkpoint() @@ -75,30 +75,28 @@ async def foo(): await trio.lowlevel.checkpoint() await trio.lowlevel.checkpoint() - # TODO: should probably raise the error at the call, rather than at the with statement # fmt: off - with ( # ASYNC912: 4 + with ( # a # b - trio.move_on_after(0.1) + trio.move_on_after(0.1) # ASYNC912: 12 # c ): if bar(): await trio.lowlevel.checkpoint() - with ( # ASYNC912: 4 + with ( open(""), - trio.move_on_at(5), + trio.move_on_at(5), # ASYNC912: 12 open(""), ): if bar(): await trio.lowlevel.checkpoint() # fmt: on - # TODO: only raises one error currently, can make it raise 2(?) - with ( # ASYNC912: 4 - trio.move_on_after(0.1), - trio.fail_at(5), + with ( + trio.move_on_after(0.1), # ASYNC912: 8 + trio.fail_at(5), # ASYNC912: 8 ): if bar(): await trio.lowlevel.checkpoint() @@ -119,7 +117,7 @@ def condition() -> bool: async def livelocks_2(): - with trio.move_on_after(0.1): # ASYNC912: 4 + with trio.move_on_after(0.1): # ASYNC912: 9 while condition(): try: await trio.sleep("1") # type: ignore @@ -146,3 +144,14 @@ async def livelocks_3(): def sync_func(): with trio.move_on_after(10): ... + + +async def check_yield_logic(): + # Does not raise any of async100 or async912, as the yield is treated + # as a checkpoint because the parent context may checkpoint. + with trio.move_on_after(1): + yield + with trio.move_on_after(1): + if bar(): + await trio.lowlevel.checkpoint() + yield diff --git a/tests/eval_files/async912_asyncio.py b/tests/eval_files/async912_asyncio.py index d41f1562..1e840b44 100644 --- a/tests/eval_files/async912_asyncio.py +++ b/tests/eval_files/async912_asyncio.py @@ -6,7 +6,7 @@ # ASYNC100 supports autofix, but ASYNC912 doesn't, so we must run with NOAUTOFIX # NOAUTOFIX -# timeout[_at] added in py3.11 +# timeout[_at] re-exported in the main asyncio namespace in py3.11 # mypy: disable-error-code=attr-defined import asyncio @@ -17,19 +17,37 @@ def bar() -> bool: async def foo(): + # async100 async with asyncio.timeout(10): # ASYNC100: 15, "asyncio", "timeout" ... async with asyncio.timeout_at(10): # ASYNC100: 15, "asyncio", "timeout_at" ... + async with asyncio.timeouts.timeout( + 10 + ): # ASYNC100: 15, "asyncio.timeouts", "timeout" + ... + async with asyncio.timeouts.timeout_at( + 10 + ): # ASYNC100: 15, "asyncio.timeouts", "timeout_at" + ... + # no errors async with asyncio.timeout(10): await foo() async with asyncio.timeout_at(10): await foo() - async with asyncio.timeout_at(10): # ASYNC912: 4 + # async912 + async with asyncio.timeout_at(10): # ASYNC912: 15 + if bar(): + await foo() + async with asyncio.timeout(10): # ASYNC912: 15 + if bar(): + await foo() + + async with asyncio.timeouts.timeout(10): # ASYNC912: 15 if bar(): await foo() - async with asyncio.timeout(10): # ASYNC912: 4 + async with asyncio.timeouts.timeout_at(10): # ASYNC912: 15 if bar(): await foo() diff --git a/tests/eval_files/async91x_autofix.py b/tests/eval_files/async91x_autofix.py index 9aa6a0d1..7ce0a359 100644 --- a/tests/eval_files/async91x_autofix.py +++ b/tests/eval_files/async91x_autofix.py @@ -12,8 +12,6 @@ from typing import Any -import trio - def bar() -> Any: ... @@ -121,10 +119,12 @@ async def livelocks(): ... -# this will autofix 910 by adding a checkpoint outside the loop +# this will autofix 910 by adding a checkpoint outside the loop, which doesn't actually +# help, and the method still isn't guaranteed to checkpoint in case bar() always returns +# True. async def no_checkpoint(): # ASYNC910: 0, "exit", Statement("function definition", lineno) while bar(): try: - await trio.sleep("1") # type: ignore[arg-type] - except ValueError: + await foo("1") # type: ignore[call-arg] + except TypeError: ... diff --git a/tests/test_flake8_async.py b/tests/test_flake8_async.py index 3fa78582..cad9d2a2 100644 --- a/tests/test_flake8_async.py +++ b/tests/test_flake8_async.py @@ -121,7 +121,6 @@ def check_autofix( # the result of doing that. THIS IS DANGEROUS assert not (magic_markers.AUTOFIX and magic_markers.NOAUTOFIX) if magic_markers.NOAUTOFIX: - assert "# AUTOFIX" not in unfixed_code print(f"eval file {test} marked with dangerous marker NOAUTOFIX") return @@ -133,11 +132,11 @@ def check_autofix( # code raises errors on asyncio, but does not support autofixing for it or (library == "asyncio" and magic_markers.ASYNCIO_NO_AUTOFIX) ): - assert unfixed_code == visited_code + assert ( + unfixed_code == visited_code + ), "Code changed after visiting, but magic markers say it shouldn't change." return - # if AUTOFIX, and library_NO_ERROR, assert file content isn't changed - # the full generated source code, saved from a previous run if test not in autofix_files: autofix_files[test] = AUTOFIX_DIR / (test.lower() + ".py") From e7c5ccdfd6ab06a07fb11b0b00bff7431b0d0ab6 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Sun, 5 May 2024 14:30:51 +0200 Subject: [PATCH 4/5] fix silly formatting --- tests/eval_files/async912_asyncio.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/eval_files/async912_asyncio.py b/tests/eval_files/async912_asyncio.py index 1e840b44..5a4227eb 100644 --- a/tests/eval_files/async912_asyncio.py +++ b/tests/eval_files/async912_asyncio.py @@ -22,13 +22,13 @@ async def foo(): ... async with asyncio.timeout_at(10): # ASYNC100: 15, "asyncio", "timeout_at" ... - async with asyncio.timeouts.timeout( + async with asyncio.timeouts.timeout( # ASYNC100: 15, "asyncio.timeouts", "timeout" 10 - ): # ASYNC100: 15, "asyncio.timeouts", "timeout" + ): ... - async with asyncio.timeouts.timeout_at( + async with asyncio.timeouts.timeout_at( # ASYNC100: 15, "asyncio.timeouts", "timeout_at" 10 - ): # ASYNC100: 15, "asyncio.timeouts", "timeout_at" + ): ... # no errors From df403c3a98943be5bcfb1220557ec690e8a478d7 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 10 May 2024 14:00:33 +0200 Subject: [PATCH 5/5] use matcher for matching instead of custom logic --- flake8_async/visitors/helpers.py | 87 ++++++++++++++++------------ tests/eval_files/async912.py | 29 +++++++++- tests/eval_files/async912_asyncio.py | 23 ++++++++ 3 files changed, 101 insertions(+), 38 deletions(-) diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py index 7b3147e3..d33e0992 100644 --- a/flake8_async/visitors/helpers.py +++ b/flake8_async/visitors/helpers.py @@ -323,55 +323,68 @@ class AttributeCall(NamedTuple): function: str +# the custom __or__ in libcst breaks pyright type checking. It's possible to use +# `Union` as a workaround ... except pyupgrade will automatically replace that. +# So we have to resort to specifying one of the base classes. +# See https://github.com/Instagram/LibCST/issues/1143 +def build_cst_matcher(attr: str) -> m.BaseExpression: + """Build a cst matcher structure with attributes&names matching a string `a.b.c`.""" + if "." not in attr: + return m.Name(value=attr) + body, tail = attr.rsplit(".") + return m.Attribute(value=build_cst_matcher(body), attr=m.Name(value=tail)) + + +def identifier_to_string(attr: cst.Name | cst.Attribute) -> str: + if isinstance(attr, cst.Name): + return attr.value + assert isinstance(attr.value, (cst.Attribute, cst.Name)) + return identifier_to_string(attr.value) + "." + attr.attr.value + + def with_has_call( node: cst.With, *names: str, base: Iterable[str] = ("trio", "anyio") ) -> list[AttributeCall]: + """Check if a with statement has a matching call, returning a list with matches. + + `names` specify the names of functions to match, `base` specifies the + library/module(s) the function must be in. + The list elements in the return value are named tuples with the matched node, + base and function. + + Examples_ + + `with_has_call(node, "bar", base="foo")` matches foo.bar. + `with_has_call(node, "bar", "bee", base=("foo", "a.b.c")` matches + `foo.bar`, `foo.bee`, `a.b.c.bar`, and `a.b.c.bee`. + + """ if isinstance(base, str): base = (base,) # pragma: no cover - for b in base: - if b.count(".") > 1: # pragma: no cover - raise NotImplementedError("Does not support 3-module bases atm.") + # build matcher, using SaveMatchedNode to save the base and the function name. + matcher = m.Call( + func=m.Attribute( + value=m.SaveMatchedNode( + m.OneOf(*(build_cst_matcher(b) for b in base)), name="base" + ), + attr=m.SaveMatchedNode( + oneof_names(*names), + name="function", + ), + ) + ) res_list: list[AttributeCall] = [] for item in node.items: - if res := m.extract( - item.item, - m.Call( - func=m.Attribute( - value=m.SaveMatchedNode(m.Name() | m.Attribute(), name="library"), - attr=m.SaveMatchedNode( - oneof_names(*names), - name="function", - ), - ) - ), - ): + if res := m.extract(item.item, matcher): assert isinstance(item.item, cst.Call) - assert isinstance(res["library"], (cst.Name, cst.Attribute)) + assert isinstance(res["base"], (cst.Name, cst.Attribute)) assert isinstance(res["function"], cst.Name) - library_node = res["library"] - for library_str in base: - if ( - isinstance(library_node, cst.Name) - and library_str == library_node.value - ): - break - if ( - isinstance(library_node, cst.Attribute) - and isinstance(library_node.value, cst.Name) - and "." in library_str - ): - base_1, base_2 = library_str.split(".") - if ( - library_node.attr.value == base_2 - and library_node.value.value == base_1 - ): - break - else: - continue res_list.append( - AttributeCall(item.item, library_str, res["function"].value) + AttributeCall( + item.item, identifier_to_string(res["base"]), res["function"].value + ) ) return res_list diff --git a/tests/eval_files/async912.py b/tests/eval_files/async912.py index 3ebbc678..c2abf045 100644 --- a/tests/eval_files/async912.py +++ b/tests/eval_files/async912.py @@ -7,6 +7,8 @@ # of not testing both in the same file, or running with NOAUTOFIX. # NOAUTOFIX +from typing import TypeVar + import trio @@ -27,6 +29,7 @@ async def foo(): with trio.CancelScope(0.1): # ASYNC100: 9, "trio", "CancelScope" ... + # conditional cases trigger ASYNC912 with trio.move_on_after(0.1): # ASYNC912: 9 if bar(): await trio.lowlevel.checkpoint() @@ -51,16 +54,23 @@ async def foo(): with open(""): ... + # don't error with guaranteed checkpoint with trio.move_on_after(0.1): await trio.lowlevel.checkpoint() + with trio.move_on_after(0.1): + if bar(): + await trio.lowlevel.checkpoint() + else: + await trio.lowlevel.checkpoint() + # both scopes error in nested cases with trio.move_on_after(0.1): # ASYNC912: 9 with trio.move_on_after(0.1): # ASYNC912: 13 if bar(): await trio.lowlevel.checkpoint() # We don't know which cancelscope will trigger first, so to avoid false - # positives on tricky-but-valid cases we don't raise any error for the outer one. + # alarms on tricky-but-valid cases we don't raise any error for the outer one. with trio.move_on_after(0.1): with trio.move_on_after(0.1): await trio.lowlevel.checkpoint() @@ -75,6 +85,7 @@ async def foo(): await trio.lowlevel.checkpoint() await trio.lowlevel.checkpoint() + # check correct line gives error # fmt: off with ( # a @@ -94,6 +105,7 @@ async def foo(): await trio.lowlevel.checkpoint() # fmt: on + # error on each call with multiple matching calls in the same with with ( trio.move_on_after(0.1), # ASYNC912: 8 trio.fail_at(5), # ASYNC912: 8 @@ -101,6 +113,21 @@ async def foo(): if bar(): await trio.lowlevel.checkpoint() + # wrapped calls do not raise errors + T = TypeVar("T") + + def customWrapper(a: T) -> T: + return a + + with customWrapper(trio.fail_at(10)): + ... + with (res := trio.fail_at(10)): + ... + # but saving with `as` does + with trio.fail_at(10) as res: # ASYNC912: 9 + if bar(): + await trio.lowlevel.checkpoint() + # TODO: issue #240 async def livelocks(): diff --git a/tests/eval_files/async912_asyncio.py b/tests/eval_files/async912_asyncio.py index 5a4227eb..ef9200bf 100644 --- a/tests/eval_files/async912_asyncio.py +++ b/tests/eval_files/async912_asyncio.py @@ -11,11 +11,16 @@ import asyncio +from typing import Any + def bar() -> bool: return False +def customWrapper(a: object) -> object: ... + + async def foo(): # async100 async with asyncio.timeout(10): # ASYNC100: 15, "asyncio", "timeout" @@ -51,3 +56,21 @@ async def foo(): async with asyncio.timeouts.timeout_at(10): # ASYNC912: 15 if bar(): await foo() + + # double check that helper methods used by visitor don't trigger erroneously + timeouts: Any + timeout_at: Any + async with asyncio.timeout_at.timeouts(10): + ... + async with timeouts.asyncio.timeout_at(10): + ... + async with timeouts.timeout_at.asyncio(10): + ... + async with timeout_at.asyncio.timeouts(10): + ... + async with timeout_at.timeouts.asyncio(10): + ... + async with foo.timeout(10): + ... + async with asyncio.timeouts(10): + ...