From 5a971f7d22a4185ee7b3ea6b04a4205a91ed02fa Mon Sep 17 00:00:00 2001 From: Derek Wan Date: Sun, 15 Sep 2024 11:14:39 +0900 Subject: [PATCH 1/5] specific try_await error --- src/utilities/asyncio.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/utilities/asyncio.py b/src/utilities/asyncio.py index d2e7641db..69682e4a3 100644 --- a/src/utilities/asyncio.py +++ b/src/utilities/asyncio.py @@ -1,9 +1,12 @@ from __future__ import annotations from functools import wraps +from re import search from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload from utilities.functions import ensure_not_none +from utilities.iterables import OneError, one +from utilities.text import EnsureStrError, ensure_str if TYPE_CHECKING: from collections.abc import AsyncGenerator, Callable, Coroutine @@ -251,8 +254,23 @@ async def try_await(obj: MaybeAwaitable[_T], /) -> _T: """Try await a value from an object.""" try: return await cast(Awaitable[_T], obj) - except TypeError: - return cast(_T, obj) + except TypeError as error: + try: + text = ensure_str(one(error.args)) + except (EnsureStrError, OneError): + raise error from None + if search("object .* can't be used in 'await' expression", text): + return cast(_T, obj) + raise + + +@dataclass(kw_only=True) +class TryAwaitError(Exception): + obj: Any + + @override + def __str__(self) -> str: + return f"Object must be awaitable; got {self.obj}" __all__ = [ @@ -271,4 +289,5 @@ async def try_await(obj: MaybeAwaitable[_T], /) -> _T: "to_set", "to_sorted", "try_await", + 'TryAwaitError' ] From 482f31fc6e20176ceedf6ad81f3386a6f0c38665 Mon Sep 17 00:00:00 2001 From: Derek Wan Date: Sun, 15 Sep 2024 13:12:09 +0900 Subject: [PATCH 2/5] more --- requirements.txt | 8 ++++---- requirements/jupyter.txt | 4 ++-- requirements/slack-sdk.txt | 2 +- requirements/streamlit.txt | 2 +- src/tests/test_asyncio.py | 40 ++++++++++++++++++++++++-------------- src/utilities/__init__.py | 2 +- src/utilities/asyncio.py | 10 ---------- 7 files changed, 34 insertions(+), 34 deletions(-) diff --git a/requirements.txt b/requirements.txt index 5d804af43..5df356853 100644 --- a/requirements.txt +++ b/requirements.txt @@ -128,7 +128,7 @@ hyperlink==21.0.0 # via hatch hypothesis==6.112.1 # via dycw-utilities (pyproject.toml) -idna==3.8 +idna==3.9 # via # anyio # httpx @@ -254,7 +254,7 @@ pillow==10.4.0 # streamlit pip==24.2 # via pypiserver -platformdirs==4.3.2 +platformdirs==4.3.3 # via # hatch # virtualenv @@ -427,7 +427,7 @@ urllib3==2.2.3 # via requests userpath==1.9.2 # via hatch -uv==0.4.9 +uv==0.4.10 # via hatch vegafusion==1.6.9 # via dycw-utilities (pyproject.toml) @@ -447,7 +447,7 @@ wrapt==1.16.0 # via deprecated yarl==1.11.1 # via aiohttp -zipp==3.20.1 +zipp==3.20.2 # via importlib-metadata zstandard==0.23.0 # via hatch diff --git a/requirements/jupyter.txt b/requirements/jupyter.txt index e256acd4b..e7554ba5b 100644 --- a/requirements/jupyter.txt +++ b/requirements/jupyter.txt @@ -60,7 +60,7 @@ httpx==0.27.2 # via jupyterlab hypothesis==6.112.1 # via dycw-utilities (pyproject.toml) -idna==3.8 +idna==3.9 # via # anyio # httpx @@ -169,7 +169,7 @@ parso==0.8.4 # via jedi pexpect==4.9.0 # via ipython -platformdirs==4.3.2 +platformdirs==4.3.3 # via jupyter-core pluggy==1.5.0 # via pytest diff --git a/requirements/slack-sdk.txt b/requirements/slack-sdk.txt index 34d79da5f..5263f8e26 100644 --- a/requirements/slack-sdk.txt +++ b/requirements/slack-sdk.txt @@ -18,7 +18,7 @@ frozenlist==1.4.1 # aiosignal hypothesis==6.112.1 # via dycw-utilities (pyproject.toml) -idna==3.8 +idna==3.9 # via yarl iniconfig==2.0.0 # via pytest diff --git a/requirements/streamlit.txt b/requirements/streamlit.txt index 3d46441dd..518c63f8d 100644 --- a/requirements/streamlit.txt +++ b/requirements/streamlit.txt @@ -25,7 +25,7 @@ gitpython==3.1.43 # via streamlit hypothesis==6.112.1 # via dycw-utilities (pyproject.toml) -idna==3.8 +idna==3.9 # via requests iniconfig==2.0.0 # via pytest diff --git a/src/tests/test_asyncio.py b/src/tests/test_asyncio.py index 6818a081c..2d8af9575 100644 --- a/src/tests/test_asyncio.py +++ b/src/tests/test_asyncio.py @@ -442,24 +442,14 @@ async def test_reverse( class TestTryAwait: - async def test_maybe_awaitable(self) -> None: - from utilities.random import SYSTEM_RANDOM - - if SYSTEM_RANDOM.random() <= 0.5: - - async def func(*, value: bool) -> bool: # pyright: ignore[reportRedeclaration] - await sleep(0.01) - return not value - - else: - - def func(*, value: bool) -> bool: - return not value + async def test_sync(self) -> None: + def func(*, value: bool) -> bool: + return not value result = await try_await(func(value=True)) assert result is False - async def test_awaitable(self) -> None: + async def test_async(self) -> None: async def func(*, value: bool) -> bool: await sleep(0.01) return not value @@ -467,9 +457,29 @@ async def func(*, value: bool) -> bool: result = await try_await(func(value=True)) assert result is False - async def test_sync(self) -> None: + @mark.parametrize("cls", [param(ValueError), param(TypeError)], ids=str) + async def test_error_sync(self, *, cls: type[Exception]) -> None: def func(*, value: bool) -> bool: + if not value: + msg = f"Value must be True; got {value}" + raise cls(msg) + return not value + + result = await try_await(func(value=True)) + assert result is False + with raises(cls, match="Value must be True; got False"): + _ = await try_await(func(value=False)) + + @mark.parametrize("cls", [param(ValueError), param(TypeError)], ids=str) + async def test_error_async(self, *, cls: type[Exception]) -> None: + async def func(*, value: bool) -> bool: + if not value: + msg = f"Value must be True; got {value}" + raise cls(msg) + await sleep(0.01) return not value result = await try_await(func(value=True)) assert result is False + with raises(cls, match="Value must be True; got False"): + _ = await try_await(func(value=False)) diff --git a/src/utilities/__init__.py b/src/utilities/__init__.py index c76758c5a..3481fa31b 100644 --- a/src/utilities/__init__.py +++ b/src/utilities/__init__.py @@ -1,3 +1,3 @@ from __future__ import annotations -__version__ = "0.54.2" +__version__ = "0.54.3" diff --git a/src/utilities/asyncio.py b/src/utilities/asyncio.py index 69682e4a3..f74749ab2 100644 --- a/src/utilities/asyncio.py +++ b/src/utilities/asyncio.py @@ -264,15 +264,6 @@ async def try_await(obj: MaybeAwaitable[_T], /) -> _T: raise -@dataclass(kw_only=True) -class TryAwaitError(Exception): - obj: Any - - @override - def __str__(self) -> str: - return f"Object must be awaitable; got {self.obj}" - - __all__ = [ "Coroutine1", "MaybeAwaitable", @@ -289,5 +280,4 @@ def __str__(self) -> str: "to_set", "to_sorted", "try_await", - 'TryAwaitError' ] From 5d5dbd911374878d5168c7be8104167c55241c19 Mon Sep 17 00:00:00 2001 From: Derek Wan Date: Sun, 15 Sep 2024 13:13:05 +0900 Subject: [PATCH 3/5] more --- src/tests/test_asyncio.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tests/test_asyncio.py b/src/tests/test_asyncio.py index 2d8af9575..4a3885c8b 100644 --- a/src/tests/test_asyncio.py +++ b/src/tests/test_asyncio.py @@ -458,7 +458,7 @@ async def func(*, value: bool) -> bool: assert result is False @mark.parametrize("cls", [param(ValueError), param(TypeError)], ids=str) - async def test_error_sync(self, *, cls: type[Exception]) -> None: + async def test_error_std_msg_sync(self, *, cls: type[Exception]) -> None: def func(*, value: bool) -> bool: if not value: msg = f"Value must be True; got {value}" @@ -471,7 +471,7 @@ def func(*, value: bool) -> bool: _ = await try_await(func(value=False)) @mark.parametrize("cls", [param(ValueError), param(TypeError)], ids=str) - async def test_error_async(self, *, cls: type[Exception]) -> None: + async def test_error_std_msg_async(self, *, cls: type[Exception]) -> None: async def func(*, value: bool) -> bool: if not value: msg = f"Value must be True; got {value}" From 9dceaa85766941109f07e2c2de4e5f0ad09928a5 Mon Sep 17 00:00:00 2001 From: Derek Wan Date: Sun, 15 Sep 2024 14:56:44 +0900 Subject: [PATCH 4/5] more --- src/tests/test_asyncio.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/tests/test_asyncio.py b/src/tests/test_asyncio.py index 4a3885c8b..40411eea9 100644 --- a/src/tests/test_asyncio.py +++ b/src/tests/test_asyncio.py @@ -465,8 +465,6 @@ def func(*, value: bool) -> bool: raise cls(msg) return not value - result = await try_await(func(value=True)) - assert result is False with raises(cls, match="Value must be True; got False"): _ = await try_await(func(value=False)) @@ -479,7 +477,26 @@ async def func(*, value: bool) -> bool: await sleep(0.01) return not value - result = await try_await(func(value=True)) - assert result is False with raises(cls, match="Value must be True; got False"): _ = await try_await(func(value=False)) + + @mark.parametrize("cls", [param(ValueError), param(TypeError)], ids=str) + async def test_error_non_std_msg_sync(self, *, cls: type[Exception]) -> None: + def func(*, value: bool) -> bool: + if not value: + raise cls(value) + return not value + + with raises(cls, match="False"): + _ = await try_await(func(value=False)) + + @mark.parametrize("cls", [param(ValueError), param(TypeError)], ids=str) + async def test_error_non_std_msg_async(self, *, cls: type[Exception]) -> None: + async def func(*, value: bool) -> bool: + if not value: + raise cls(value) + await sleep(0.01) + return not value + + with raises(cls, match="False"): + _ = await try_await(func(value=False)) From cc10c62c37b3fa52725a8c1986d9f704ef1c93cd Mon Sep 17 00:00:00 2001 From: Derek Wan Date: Sun, 15 Sep 2024 14:58:04 +0900 Subject: [PATCH 5/5] more --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a3b34840..e3b9bcbde 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ default_language_version: repos: # fixers - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.4 + rev: v0.6.5 hooks: - id: ruff args: [--fix]