Skip to content

Commit

Permalink
Have try_await suppress only the 'non-awaitable` error (#732)
Browse files Browse the repository at this point in the history
  • Loading branch information
dycw authored Sep 15, 2024
1 parent aaf4ef7 commit 70c6ab6
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_language_version:
repos:
# fixers
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.4
rev: v0.6.5
hooks:
- id: ruff
args: [--fix]
Expand Down
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ hyperlink==21.0.0
# via hatch
hypothesis==6.112.1
# via dycw-utilities (pyproject.toml)
idna==3.8
idna==3.9
# via
# anyio
# httpx
Expand Down Expand Up @@ -254,7 +254,7 @@ pillow==10.4.0
# streamlit
pip==24.2
# via pypiserver
platformdirs==4.3.2
platformdirs==4.3.3
# via
# hatch
# virtualenv
Expand Down Expand Up @@ -427,7 +427,7 @@ urllib3==2.2.3
# via requests
userpath==1.9.2
# via hatch
uv==0.4.9
uv==0.4.10
# via hatch
vegafusion==1.6.9
# via dycw-utilities (pyproject.toml)
Expand All @@ -447,7 +447,7 @@ wrapt==1.16.0
# via deprecated
yarl==1.11.1
# via aiohttp
zipp==3.20.1
zipp==3.20.2
# via importlib-metadata
zstandard==0.23.0
# via hatch
4 changes: 2 additions & 2 deletions requirements/jupyter.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ httpx==0.27.2
# via jupyterlab
hypothesis==6.112.1
# via dycw-utilities (pyproject.toml)
idna==3.8
idna==3.9
# via
# anyio
# httpx
Expand Down Expand Up @@ -169,7 +169,7 @@ parso==0.8.4
# via jedi
pexpect==4.9.0
# via ipython
platformdirs==4.3.2
platformdirs==4.3.3
# via jupyter-core
pluggy==1.5.0
# via pytest
Expand Down
2 changes: 1 addition & 1 deletion requirements/slack-sdk.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ frozenlist==1.4.1
# aiosignal
hypothesis==6.112.1
# via dycw-utilities (pyproject.toml)
idna==3.8
idna==3.9
# via yarl
iniconfig==2.0.0
# via pytest
Expand Down
2 changes: 1 addition & 1 deletion requirements/streamlit.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ gitpython==3.1.43
# via streamlit
hypothesis==6.112.1
# via dycw-utilities (pyproject.toml)
idna==3.8
idna==3.9
# via requests
iniconfig==2.0.0
# via pytest
Expand Down
61 changes: 44 additions & 17 deletions src/tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,34 +442,61 @@ async def test_reverse(


class TestTryAwait:
async def test_maybe_awaitable(self) -> None:
from utilities.random import SYSTEM_RANDOM

if SYSTEM_RANDOM.random() <= 0.5:

async def func(*, value: bool) -> bool: # pyright: ignore[reportRedeclaration]
await sleep(0.01)
return not value

else:

def func(*, value: bool) -> bool:
return not value
async def test_sync(self) -> None:
def func(*, value: bool) -> bool:
return not value

result = await try_await(func(value=True))
assert result is False

async def test_awaitable(self) -> None:
async def test_async(self) -> None:
async def func(*, value: bool) -> bool:
await sleep(0.01)
return not value

result = await try_await(func(value=True))
assert result is False

async def test_sync(self) -> None:
@mark.parametrize("cls", [param(ValueError), param(TypeError)], ids=str)
async def test_error_std_msg_sync(self, *, cls: type[Exception]) -> None:
def func(*, value: bool) -> bool:
if not value:
msg = f"Value must be True; got {value}"
raise cls(msg)
return not value

result = await try_await(func(value=True))
assert result is False
with raises(cls, match="Value must be True; got False"):
_ = await try_await(func(value=False))

@mark.parametrize("cls", [param(ValueError), param(TypeError)], ids=str)
async def test_error_std_msg_async(self, *, cls: type[Exception]) -> None:
async def func(*, value: bool) -> bool:
if not value:
msg = f"Value must be True; got {value}"
raise cls(msg)
await sleep(0.01)
return not value

with raises(cls, match="Value must be True; got False"):
_ = await try_await(func(value=False))

@mark.parametrize("cls", [param(ValueError), param(TypeError)], ids=str)
async def test_error_non_std_msg_sync(self, *, cls: type[Exception]) -> None:
def func(*, value: bool) -> bool:
if not value:
raise cls(value)
return not value

with raises(cls, match="False"):
_ = await try_await(func(value=False))

@mark.parametrize("cls", [param(ValueError), param(TypeError)], ids=str)
async def test_error_non_std_msg_async(self, *, cls: type[Exception]) -> None:
async def func(*, value: bool) -> bool:
if not value:
raise cls(value)
await sleep(0.01)
return not value

with raises(cls, match="False"):
_ = await try_await(func(value=False))
2 changes: 1 addition & 1 deletion src/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import annotations

__version__ = "0.54.2"
__version__ = "0.54.3"
13 changes: 11 additions & 2 deletions src/utilities/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from functools import wraps
from re import search
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload

from utilities.functions import ensure_not_none
from utilities.iterables import OneError, one
from utilities.text import EnsureStrError, ensure_str

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Callable, Coroutine
Expand Down Expand Up @@ -251,8 +254,14 @@ async def try_await(obj: MaybeAwaitable[_T], /) -> _T:
"""Try await a value from an object."""
try:
return await cast(Awaitable[_T], obj)
except TypeError:
return cast(_T, obj)
except TypeError as error:
try:
text = ensure_str(one(error.args))
except (EnsureStrError, OneError):
raise error from None
if search("object .* can't be used in 'await' expression", text):
return cast(_T, obj)
raise


__all__ = [
Expand Down

0 comments on commit 70c6ab6

Please sign in to comment.