diff --git a/requirements.txt b/requirements.txt index 54e57354d..e236e1ba1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -423,7 +423,7 @@ tzdata==2024.1 # pandas tzlocal==5.2 # via dycw-utilities (pyproject.toml) -urllib3==2.2.2 +urllib3==2.2.3 # via requests userpath==1.9.2 # via hatch diff --git a/requirements/jupyter.txt b/requirements/jupyter.txt index 15db4a16e..78df39259 100644 --- a/requirements/jupyter.txt +++ b/requirements/jupyter.txt @@ -302,7 +302,7 @@ tzdata==2024.1 # via pandas uri-template==1.3.0 # via jsonschema -urllib3==2.2.2 +urllib3==2.2.3 # via requests wcwidth==0.2.13 # via prompt-toolkit diff --git a/requirements/streamlit.txt b/requirements/streamlit.txt index adea76f49..4ab845af4 100644 --- a/requirements/streamlit.txt +++ b/requirements/streamlit.txt @@ -125,5 +125,5 @@ typing-extensions==4.12.2 # streamlit tzdata==2024.1 # via pandas -urllib3==2.2.2 +urllib3==2.2.3 # via requests diff --git a/src/tests/test_asyncio.py b/src/tests/test_asyncio.py index c5b3d4142..016dc47f7 100644 --- a/src/tests/test_asyncio.py +++ b/src/tests/test_asyncio.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, ClassVar from hypothesis import given -from pytest import mark, param, raises +from pytest import CaptureFixture, mark, param, raises from utilities.asyncio import ( ReduceAsyncError, @@ -16,16 +16,25 @@ groupby_async_list, is_awaitable, reduce_async, + send_and_next_async, + start_async_generator_coroutine, timeout_dur, to_list, to_set, to_sorted, try_await, ) +from utilities.functions import ensure_not_none from utilities.hypothesis import durations if TYPE_CHECKING: - from collections.abc import AsyncIterator, Iterable, Iterator, Sequence + from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Iterable, + Iterator, + Sequence, + ) from utilities.types import Duration @@ -257,6 +266,70 @@ async def add(x: int, y: int, /) -> int: _ = await reduce_async(add, []) +class TestSendAndNextAsync: + async def test_main(self, *, capsys: CaptureFixture) -> None: + @start_async_generator_coroutine + async def func() -> AsyncGenerator[int | None, float | None]: + print("Initial") # noqa: T201 + while True: + input_ = yield + output = round(ensure_not_none(input_)) + if output >= 0: + print(f"Received {input_}, yielding {output}") # noqa: T201 + yield output + else: + break + + generator = await func() + out = capsys.readouterr().out + assert out == "Initial\n", out + result = await send_and_next_async(0.1, generator) + assert result == 0 + out = capsys.readouterr().out + assert out == "Received 0.1, yielding 0\n", out + result = await send_and_next_async(0.9, generator) + assert result == 1 + out = capsys.readouterr().out + assert out == "Received 0.9, yielding 1\n", out + result = await send_and_next_async(1.1, generator) + assert result == 1 + out = capsys.readouterr().out + assert out == "Received 1.1, yielding 1\n", out + with raises(StopAsyncIteration) as exc: + _ = await send_and_next_async(-0.9, generator) + assert exc.value.args == () + + +class TestStartAsyncGeneratorCoroutine: + async def test_main(self, *, capsys: CaptureFixture) -> None: + @start_async_generator_coroutine + async def func() -> AsyncGenerator[int, float]: + print("Pre-initial") # noqa: T201 + x = yield 0 + print(f"Post-initial; x={x}") # noqa: T201 + while x >= 0: + print(f"Pre-yield; x={x}") # noqa: T201 + x = yield round(x) + print(f"Post-yield; x={x}") # noqa: T201 + await sleep(0.01) + + generator = await func() + out = capsys.readouterr().out + assert out == "Pre-initial\n", out + assert await generator.asend(0.1) == 0 + out = capsys.readouterr().out + assert out == "Post-initial; x=0.1\nPre-yield; x=0.1\n", out + assert await generator.asend(0.9) == 1 + out = capsys.readouterr().out + assert out == "Post-yield; x=0.9\nPre-yield; x=0.9\n", out + assert await generator.asend(1.1) == 1 + out = capsys.readouterr().out + assert out == "Post-yield; x=1.1\nPre-yield; x=1.1\n", out + with raises(StopAsyncIteration) as exc: + _ = await generator.asend(-0.9) + assert exc.value.args == () + + class TestTimeoutDur: @given(duration=durations()) async def test_main(self, *, duration: Duration) -> None: diff --git a/src/tests/test_functions.py b/src/tests/test_functions.py index 1ad215b4c..bdd808c3a 100644 --- a/src/tests/test_functions.py +++ b/src/tests/test_functions.py @@ -8,10 +8,11 @@ from hypothesis import given from hypothesis.strategies import booleans, integers -from pytest import mark, param +from pytest import CaptureFixture, mark, param, raises from utilities.asyncio import try_await from utilities.functions import ( + ensure_not_none, first, get_class, get_class_name, @@ -21,10 +22,12 @@ is_not_none, not_func, second, + send_and_next, + start_generator_coroutine, ) if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Generator _T = TypeVar("_T") @@ -156,3 +159,67 @@ class TestSecond: def test_main(self, *, x: int, y: int) -> None: pair = x, y assert second(pair) == y + + +class TestSendAndNext: + def test_main(self, *, capsys: CaptureFixture) -> None: + @start_generator_coroutine + def func() -> Generator[int | None, float | None, str]: + print("Initial") # noqa: T201 + while True: + input_ = yield + output = round(ensure_not_none(input_)) + if output >= 0: + print(f"Received {input_}, yielding {output}") # noqa: T201 + yield output + else: + return "Done" + + generator = func() + out = capsys.readouterr().out + assert out == "Initial\n", out + result = send_and_next(0.1, generator) + assert result == 0 + out = capsys.readouterr().out + assert out == "Received 0.1, yielding 0\n", out + result = send_and_next(0.9, generator) + assert result == 1 + out = capsys.readouterr().out + assert out == "Received 0.9, yielding 1\n", out + result = send_and_next(1.1, generator) + assert result == 1 + out = capsys.readouterr().out + assert out == "Received 1.1, yielding 1\n", out + with raises(StopIteration) as exc: + _ = send_and_next(-0.9, generator) + assert exc.value.args == ("Done",) + + +class TestStartGeneratorCoroutine: + def test_main(self, *, capsys: CaptureFixture) -> None: + @start_generator_coroutine + def func() -> Generator[int, float, str]: + print("Pre-initial") # noqa: T201 + x = yield 0 + print(f"Post-initial; x={x}") # noqa: T201 + while x >= 0: + print(f"Pre-yield; x={x}") # noqa: T201 + x = yield round(x) + print(f"Post-yield; x={x}") # noqa: T201 + return "Done" + + generator = func() + out = capsys.readouterr().out + assert out == "Pre-initial\n", out + assert generator.send(0.1) == 0 + out = capsys.readouterr().out + assert out == "Post-initial; x=0.1\nPre-yield; x=0.1\n", out + assert generator.send(0.9) == 1 + out = capsys.readouterr().out + assert out == "Post-yield; x=0.9\nPre-yield; x=0.9\n", out + assert generator.send(1.1) == 1 + out = capsys.readouterr().out + assert out == "Post-yield; x=1.1\nPre-yield; x=1.1\n", out + with raises(StopIteration) as exc: + _ = generator.send(-0.9) + assert exc.value.args == ("Done",) diff --git a/src/tests/test_functools.py b/src/tests/test_functools.py index 3070a1d73..c2e994aa5 100644 --- a/src/tests/test_functools.py +++ b/src/tests/test_functools.py @@ -1,10 +1,13 @@ from __future__ import annotations from operator import sub +from typing import cast from hypothesis import given from hypothesis.strategies import booleans, integers +from pytest import raises +from utilities.functions import EnsureNotNoneError, ensure_not_none from utilities.functools import cache, lru_cache, partial @@ -23,6 +26,17 @@ def func(x: int, /) -> int: assert counter == 1 +class TestEnsureNotNone: + def test_main(self) -> None: + maybe_int = cast(int | None, 0) + result = ensure_not_none(maybe_int) + assert result == 0 + + def test_error(self) -> None: + with raises(EnsureNotNoneError, match="Object .* must not be None"): + _ = ensure_not_none(None) + + class TestLRUCache: def test_no_arguments(self) -> None: counter = 0 diff --git a/src/tests/test_types.py b/src/tests/test_types.py index 32ec7b029..013c4cdeb 100644 --- a/src/tests/test_types.py +++ b/src/tests/test_types.py @@ -2,7 +2,7 @@ import datetime as dt from pathlib import Path -from typing import Any, cast +from typing import Any from pytest import mark, param, raises @@ -18,7 +18,6 @@ EnsureHashableError, EnsureIntError, EnsureMemberError, - EnsureNotNoneError, EnsureNumberError, EnsureSizedError, EnsureSizedNotStrError, @@ -33,7 +32,6 @@ ensure_hashable, ensure_int, ensure_member, - ensure_not_none, ensure_number, ensure_sized, ensure_sized_not_str, @@ -216,17 +214,6 @@ def test_error(self, *, nullable: bool, match: str) -> None: _ = ensure_member(sentinel, {True, False}, nullable=nullable) -class TestEnsureNotNone: - def test_main(self) -> None: - maybe_int = cast(int | None, 0) - result = ensure_not_none(maybe_int) - assert result == 0 - - def test_error(self) -> None: - with raises(EnsureNotNoneError, match="Object .* must not be None"): - _ = ensure_not_none(None) - - class TestEnsureNumber: @mark.parametrize( ("obj", "nullable"), diff --git a/src/utilities/__init__.py b/src/utilities/__init__.py index d613f4298..262a3af47 100644 --- a/src/utilities/__init__.py +++ b/src/utilities/__init__.py @@ -1,3 +1,3 @@ from __future__ import annotations -__version__ = "0.53.3" +__version__ = "0.54.0" diff --git a/src/utilities/asyncio.py b/src/utilities/asyncio.py index fd310dfe4..d2e7641db 100644 --- a/src/utilities/asyncio.py +++ b/src/utilities/asyncio.py @@ -1,5 +1,13 @@ from __future__ import annotations +from functools import wraps +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload + +from utilities.functions import ensure_not_none + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Callable, Coroutine + from asyncio import timeout from collections.abc import ( AsyncIterable, @@ -12,7 +20,7 @@ ) from dataclasses import dataclass from itertools import groupby -from typing import TYPE_CHECKING, Any, TypeGuard, TypeVar, cast, overload +from typing import TYPE_CHECKING, TypeGuard from typing_extensions import override @@ -25,7 +33,7 @@ from utilities.types import Duration - +_P = ParamSpec("_P") _T = TypeVar("_T") _U = TypeVar("_U") Coroutine1 = Coroutine[Any, Any, _T] @@ -156,6 +164,29 @@ def __str__(self) -> str: return f"Empty iterable {self.iterable} with no initial value" +async def send_and_next_async( + value: _U, generator: AsyncGenerator[_T | None, _U | None], / +) -> _T: + """Send a value to a generator, and then yield the output.""" + result = ensure_not_none(await generator.asend(value)) + _ = await anext(generator) + return result + + +def start_async_generator_coroutine( + func: Callable[_P, AsyncGenerator[_T, _U]], / +) -> Callable[_P, Coroutine1[AsyncGenerator[_T, _U]]]: + """Instantiate and then start a generator-coroutine.""" + + @wraps(func) + async def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> AsyncGenerator[_T, _U]: + coro = func(*args, **kwargs) + _ = await anext(coro) + return coro + + return wrapped + + def timeout_dur(*, duration: Duration | None = None) -> Timeout: """Timeout context manager which accepts durations.""" delay = None if duration is None else duration_to_float(duration) @@ -233,6 +264,8 @@ async def try_await(obj: MaybeAwaitable[_T], /) -> _T: "groupby_async_list", "is_awaitable", "reduce_async", + "send_and_next_async", + "start_async_generator_coroutine", "timeout_dur", "to_list", "to_set", diff --git a/src/utilities/datetime.py b/src/utilities/datetime.py index 1c65c6750..d2f2d5a5c 100644 --- a/src/utilities/datetime.py +++ b/src/utilities/datetime.py @@ -17,8 +17,8 @@ from typing_extensions import override +from utilities.functions import ensure_not_none from utilities.platform import SYSTEM, System -from utilities.types import Duration, ensure_not_none from utilities.zoneinfo import ( UTC, HongKong, @@ -31,6 +31,8 @@ from collections.abc import Iterator from zoneinfo import ZoneInfo + from utilities.types import Duration + _DAYS_PER_YEAR = 365.25 _MICROSECONDS_PER_MILLISECOND = int(1e3) diff --git a/src/utilities/functions.py b/src/utilities/functions.py index ad38d61b3..0c96a63d0 100644 --- a/src/utilities/functions.py +++ b/src/utilities/functions.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from functools import partial, wraps from types import ( BuiltinFunctionType, @@ -11,15 +12,31 @@ ) from typing import TYPE_CHECKING, Any, TypeVar, overload -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, override if TYPE_CHECKING: - from collections.abc import Callable - + from collections.abc import Callable, Generator _P = ParamSpec("_P") _T = TypeVar("_T") _U = TypeVar("_U") +_V = TypeVar("_V") + + +def ensure_not_none(obj: _T | None, /) -> _T: + """Ensure an object is not None.""" + if obj is None: + raise EnsureNotNoneError(obj=obj) + return obj + + +@dataclass(kw_only=True) +class EnsureNotNoneError(Exception): + obj: Any + + @override + def __str__(self) -> str: + return f"Object {self.obj} must not be None" def first(pair: tuple[_T, Any], /) -> _T: @@ -84,7 +101,30 @@ def second(pair: tuple[Any, _U], /) -> _U: return pair[1] +def send_and_next(value: _U, generator: Generator[_T | None, _U | None, Any], /) -> _T: + """Send a value to a generator, and then yield the output.""" + result = ensure_not_none(generator.send(value)) + _ = next(generator) + return result + + +def start_generator_coroutine( + func: Callable[_P, Generator[_T, _U, _V]], / +) -> Callable[_P, Generator[_T, _U, _V]]: + """Instantiate and then start a generator-coroutine.""" + + @wraps(func) + def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> Generator[_T, _U, _V]: + coro = func(*args, **kwargs) + _ = next(coro) + return coro + + return wrapped + + __all__ = [ + "EnsureNotNoneError", + "ensure_not_none", "first", "get_class", "get_class_name", @@ -94,4 +134,6 @@ def second(pair: tuple[Any, _U], /) -> _U: "is_not_none", "not_func", "second", + "send_and_next", + "start_generator_coroutine", ] diff --git a/src/utilities/hypothesis.py b/src/utilities/hypothesis.py index cb1247341..26e6cbd21 100644 --- a/src/utilities/hypothesis.py +++ b/src/utilities/hypothesis.py @@ -62,7 +62,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine from utilities.numpy import NDArrayB, NDArrayF, NDArrayI, NDArrayO - from utilities.redis import _RedisContainer + from utilities.redis import RedisContainer from utilities.types import Duration, Number @@ -417,7 +417,7 @@ def random_states( def redis_cms( data: DataObject, /, *, async_: MaybeSearchStrategy[bool] = _ASYNCS ) -> AbstractAsyncContextManager[ - _RedisContainer[redis.Redis] | _RedisContainer[redis.asyncio.Redis] + RedisContainer[redis.Redis] | RedisContainer[redis.asyncio.Redis] ]: """Strategy for generating redis clients (with cleanup).""" import redis # skipif-ci-and-not-linux @@ -425,7 +425,7 @@ def redis_cms( from redis.exceptions import ResponseError # skipif-ci-and-not-linux from redis.typing import KeyT # skipif-ci-and-not-linux - from utilities.redis import _RedisContainer # skipif-ci-and-not-linux + from utilities.redis import RedisContainer # skipif-ci-and-not-linux draw = lift_data(data) # skipif-ci-and-not-linux now = get_now(time_zone="local") # skipif-ci-and-not-linux @@ -436,13 +436,13 @@ def redis_cms( @asynccontextmanager async def yield_redis_async() -> ( # skipif-ci-and-not-linux - AsyncIterator[_RedisContainer[redis.asyncio.Redis]] + AsyncIterator[RedisContainer[redis.asyncio.Redis]] ): async with redis.asyncio.Redis(db=15) as client: # skipif-ci-and-not-linux keys = cast(list[KeyT], await client.keys(pattern=f"{key}_*")) with suppress(ResponseError): _ = await client.delete(*keys) - yield _RedisContainer(client=client, timestamp=now, uuid=uuid, key=key) + yield RedisContainer(client=client, timestamp=now, uuid=uuid, key=key) keys = cast(list[KeyT], await client.keys(pattern=f"{key}_*")) with suppress(ResponseError): _ = await client.delete(*keys) @@ -451,13 +451,13 @@ async def yield_redis_async() -> ( # skipif-ci-and-not-linux @asynccontextmanager async def yield_redis_sync() -> ( # skipif-ci-and-not-linux - AsyncIterator[_RedisContainer[redis.Redis]] + AsyncIterator[RedisContainer[redis.Redis]] ): with redis.Redis(db=15) as client: # skipif-ci-and-not-linux keys = cast(list[KeyT], client.keys(pattern=f"{key}_*")) with suppress(ResponseError): _ = client.delete(*keys) - yield _RedisContainer(client=client, timestamp=now, uuid=uuid, key=key) + yield RedisContainer(client=client, timestamp=now, uuid=uuid, key=key) keys = cast(list[KeyT], client.keys(pattern=f"{key}_*")) with suppress(ResponseError): _ = client.delete(*keys) diff --git a/src/utilities/redis.py b/src/utilities/redis.py index 3bdb4f5f2..ebac61dd3 100644 --- a/src/utilities/redis.py +++ b/src/utilities/redis.py @@ -57,7 +57,7 @@ @dataclass(repr=False, frozen=True, kw_only=True) -class _RedisContainer(Generic[_TRedis]): +class RedisContainer(Generic[_TRedis]): """A container for a client; for testing purposes only.""" client: _TRedis @@ -1613,6 +1613,7 @@ def _classify_response_error(error: ResponseError, /) -> _ResponseErrorKind: __all__ = [ + "RedisContainer", "TimeSeriesAddDataFrameError", "TimeSeriesAddError", "TimeSeriesMAddError", diff --git a/src/utilities/sqlalchemy_polars.py b/src/utilities/sqlalchemy_polars.py index 1d1e1872c..84b77dcc9 100644 --- a/src/utilities/sqlalchemy_polars.py +++ b/src/utilities/sqlalchemy_polars.py @@ -33,7 +33,7 @@ from utilities.datetime import is_subclass_date_not_datetime from utilities.errors import redirect_error -from utilities.functions import identity +from utilities.functions import ensure_not_none, identity from utilities.iterables import ( CheckDuplicatesError, OneError, @@ -61,7 +61,6 @@ upsert_items_async, yield_connection, ) -from utilities.types import StrMapping, ensure_not_none from utilities.zoneinfo import UTC if TYPE_CHECKING: @@ -78,6 +77,8 @@ from sqlalchemy.sql import ColumnCollection from sqlalchemy.sql.base import ReadOnlyColumnCollection + from utilities.types import StrMapping + def insert_dataframe( df: DataFrame, diff --git a/src/utilities/types.py b/src/utilities/types.py index 48d2733ed..353d9e6d5 100644 --- a/src/utilities/types.py +++ b/src/utilities/types.py @@ -269,22 +269,6 @@ def __str__(self) -> str: return f"Object {self.obj} must be a member of {self.container}{desc}" -def ensure_not_none(obj: _T | None, /) -> _T: - """Ensure an object is not None.""" - if obj is None: - raise EnsureNotNoneError(obj=obj) - return obj - - -@dataclass(kw_only=True) -class EnsureNotNoneError(Exception): - obj: Any - - @override - def __str__(self) -> str: - return f"Object {self.obj} must not be None" - - @overload def ensure_number(obj: Any, /, *, nullable: bool) -> Number | None: ... @overload @@ -420,7 +404,6 @@ def inner(obj: Any, /) -> TypeGuard[_T]: "EnsureHashableError", "EnsureIntError", "EnsureMemberError", - "EnsureNotNoneError", "EnsureNumberError", "EnsureSizedError", "EnsureSizedNotStrError", @@ -435,7 +418,6 @@ def inner(obj: Any, /) -> TypeGuard[_T]: "ensure_hashable", "ensure_int", "ensure_member", - "ensure_not_none", "ensure_number", "ensure_sized", "ensure_sized_not_str",