Skip to content

Commit

Permalink
Add start_generator_coroutine and send_and_next (and async versio…
Browse files Browse the repository at this point in the history
…ns) (#729)
  • Loading branch information
dycw authored Sep 12, 2024
1 parent d6db439 commit bbb9828
Show file tree
Hide file tree
Showing 15 changed files with 258 additions and 56 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ tzdata==2024.1
# pandas
tzlocal==5.2
# via dycw-utilities (pyproject.toml)
urllib3==2.2.2
urllib3==2.2.3
# via requests
userpath==1.9.2
# via hatch
Expand Down
2 changes: 1 addition & 1 deletion requirements/jupyter.txt
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ tzdata==2024.1
# via pandas
uri-template==1.3.0
# via jsonschema
urllib3==2.2.2
urllib3==2.2.3
# via requests
wcwidth==0.2.13
# via prompt-toolkit
Expand Down
2 changes: 1 addition & 1 deletion requirements/streamlit.txt
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,5 @@ typing-extensions==4.12.2
# streamlit
tzdata==2024.1
# via pandas
urllib3==2.2.2
urllib3==2.2.3
# via requests
77 changes: 75 additions & 2 deletions src/tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Any, ClassVar

from hypothesis import given
from pytest import mark, param, raises
from pytest import CaptureFixture, mark, param, raises

from utilities.asyncio import (
ReduceAsyncError,
Expand All @@ -16,16 +16,25 @@
groupby_async_list,
is_awaitable,
reduce_async,
send_and_next_async,
start_async_generator_coroutine,
timeout_dur,
to_list,
to_set,
to_sorted,
try_await,
)
from utilities.functions import ensure_not_none
from utilities.hypothesis import durations

if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
from collections.abc import (
AsyncGenerator,
AsyncIterator,
Iterable,
Iterator,
Sequence,
)

from utilities.types import Duration

Expand Down Expand Up @@ -257,6 +266,70 @@ async def add(x: int, y: int, /) -> int:
_ = await reduce_async(add, [])


class TestSendAndNextAsync:
async def test_main(self, *, capsys: CaptureFixture) -> None:
@start_async_generator_coroutine
async def func() -> AsyncGenerator[int | None, float | None]:
print("Initial") # noqa: T201
while True:
input_ = yield
output = round(ensure_not_none(input_))
if output >= 0:
print(f"Received {input_}, yielding {output}") # noqa: T201
yield output
else:
break

generator = await func()
out = capsys.readouterr().out
assert out == "Initial\n", out
result = await send_and_next_async(0.1, generator)
assert result == 0
out = capsys.readouterr().out
assert out == "Received 0.1, yielding 0\n", out
result = await send_and_next_async(0.9, generator)
assert result == 1
out = capsys.readouterr().out
assert out == "Received 0.9, yielding 1\n", out
result = await send_and_next_async(1.1, generator)
assert result == 1
out = capsys.readouterr().out
assert out == "Received 1.1, yielding 1\n", out
with raises(StopAsyncIteration) as exc:
_ = await send_and_next_async(-0.9, generator)
assert exc.value.args == ()


class TestStartAsyncGeneratorCoroutine:
async def test_main(self, *, capsys: CaptureFixture) -> None:
@start_async_generator_coroutine
async def func() -> AsyncGenerator[int, float]:
print("Pre-initial") # noqa: T201
x = yield 0
print(f"Post-initial; x={x}") # noqa: T201
while x >= 0:
print(f"Pre-yield; x={x}") # noqa: T201
x = yield round(x)
print(f"Post-yield; x={x}") # noqa: T201
await sleep(0.01)

generator = await func()
out = capsys.readouterr().out
assert out == "Pre-initial\n", out
assert await generator.asend(0.1) == 0
out = capsys.readouterr().out
assert out == "Post-initial; x=0.1\nPre-yield; x=0.1\n", out
assert await generator.asend(0.9) == 1
out = capsys.readouterr().out
assert out == "Post-yield; x=0.9\nPre-yield; x=0.9\n", out
assert await generator.asend(1.1) == 1
out = capsys.readouterr().out
assert out == "Post-yield; x=1.1\nPre-yield; x=1.1\n", out
with raises(StopAsyncIteration) as exc:
_ = await generator.asend(-0.9)
assert exc.value.args == ()


class TestTimeoutDur:
@given(duration=durations())
async def test_main(self, *, duration: Duration) -> None:
Expand Down
71 changes: 69 additions & 2 deletions src/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

from hypothesis import given
from hypothesis.strategies import booleans, integers
from pytest import mark, param
from pytest import CaptureFixture, mark, param, raises

from utilities.asyncio import try_await
from utilities.functions import (
ensure_not_none,
first,
get_class,
get_class_name,
Expand All @@ -21,10 +22,12 @@
is_not_none,
not_func,
second,
send_and_next,
start_generator_coroutine,
)

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Callable, Generator

_T = TypeVar("_T")

Expand Down Expand Up @@ -156,3 +159,67 @@ class TestSecond:
def test_main(self, *, x: int, y: int) -> None:
pair = x, y
assert second(pair) == y


class TestSendAndNext:
def test_main(self, *, capsys: CaptureFixture) -> None:
@start_generator_coroutine
def func() -> Generator[int | None, float | None, str]:
print("Initial") # noqa: T201
while True:
input_ = yield
output = round(ensure_not_none(input_))
if output >= 0:
print(f"Received {input_}, yielding {output}") # noqa: T201
yield output
else:
return "Done"

generator = func()
out = capsys.readouterr().out
assert out == "Initial\n", out
result = send_and_next(0.1, generator)
assert result == 0
out = capsys.readouterr().out
assert out == "Received 0.1, yielding 0\n", out
result = send_and_next(0.9, generator)
assert result == 1
out = capsys.readouterr().out
assert out == "Received 0.9, yielding 1\n", out
result = send_and_next(1.1, generator)
assert result == 1
out = capsys.readouterr().out
assert out == "Received 1.1, yielding 1\n", out
with raises(StopIteration) as exc:
_ = send_and_next(-0.9, generator)
assert exc.value.args == ("Done",)


class TestStartGeneratorCoroutine:
def test_main(self, *, capsys: CaptureFixture) -> None:
@start_generator_coroutine
def func() -> Generator[int, float, str]:
print("Pre-initial") # noqa: T201
x = yield 0
print(f"Post-initial; x={x}") # noqa: T201
while x >= 0:
print(f"Pre-yield; x={x}") # noqa: T201
x = yield round(x)
print(f"Post-yield; x={x}") # noqa: T201
return "Done"

generator = func()
out = capsys.readouterr().out
assert out == "Pre-initial\n", out
assert generator.send(0.1) == 0
out = capsys.readouterr().out
assert out == "Post-initial; x=0.1\nPre-yield; x=0.1\n", out
assert generator.send(0.9) == 1
out = capsys.readouterr().out
assert out == "Post-yield; x=0.9\nPre-yield; x=0.9\n", out
assert generator.send(1.1) == 1
out = capsys.readouterr().out
assert out == "Post-yield; x=1.1\nPre-yield; x=1.1\n", out
with raises(StopIteration) as exc:
_ = generator.send(-0.9)
assert exc.value.args == ("Done",)
14 changes: 14 additions & 0 deletions src/tests/test_functools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

from operator import sub
from typing import cast

from hypothesis import given
from hypothesis.strategies import booleans, integers
from pytest import raises

from utilities.functions import EnsureNotNoneError, ensure_not_none
from utilities.functools import cache, lru_cache, partial


Expand All @@ -23,6 +26,17 @@ def func(x: int, /) -> int:
assert counter == 1


class TestEnsureNotNone:
def test_main(self) -> None:
maybe_int = cast(int | None, 0)
result = ensure_not_none(maybe_int)
assert result == 0

def test_error(self) -> None:
with raises(EnsureNotNoneError, match="Object .* must not be None"):
_ = ensure_not_none(None)


class TestLRUCache:
def test_no_arguments(self) -> None:
counter = 0
Expand Down
15 changes: 1 addition & 14 deletions src/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import datetime as dt
from pathlib import Path
from typing import Any, cast
from typing import Any

from pytest import mark, param, raises

Expand All @@ -18,7 +18,6 @@
EnsureHashableError,
EnsureIntError,
EnsureMemberError,
EnsureNotNoneError,
EnsureNumberError,
EnsureSizedError,
EnsureSizedNotStrError,
Expand All @@ -33,7 +32,6 @@
ensure_hashable,
ensure_int,
ensure_member,
ensure_not_none,
ensure_number,
ensure_sized,
ensure_sized_not_str,
Expand Down Expand Up @@ -216,17 +214,6 @@ def test_error(self, *, nullable: bool, match: str) -> None:
_ = ensure_member(sentinel, {True, False}, nullable=nullable)


class TestEnsureNotNone:
def test_main(self) -> None:
maybe_int = cast(int | None, 0)
result = ensure_not_none(maybe_int)
assert result == 0

def test_error(self) -> None:
with raises(EnsureNotNoneError, match="Object .* must not be None"):
_ = ensure_not_none(None)


class TestEnsureNumber:
@mark.parametrize(
("obj", "nullable"),
Expand Down
2 changes: 1 addition & 1 deletion src/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import annotations

__version__ = "0.53.3"
__version__ = "0.54.0"
37 changes: 35 additions & 2 deletions src/utilities/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from __future__ import annotations

from functools import wraps
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload

from utilities.functions import ensure_not_none

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Callable, Coroutine

from asyncio import timeout
from collections.abc import (
AsyncIterable,
Expand All @@ -12,7 +20,7 @@
)
from dataclasses import dataclass
from itertools import groupby
from typing import TYPE_CHECKING, Any, TypeGuard, TypeVar, cast, overload
from typing import TYPE_CHECKING, TypeGuard

from typing_extensions import override

Expand All @@ -25,7 +33,7 @@

from utilities.types import Duration


_P = ParamSpec("_P")
_T = TypeVar("_T")
_U = TypeVar("_U")
Coroutine1 = Coroutine[Any, Any, _T]
Expand Down Expand Up @@ -156,6 +164,29 @@ def __str__(self) -> str:
return f"Empty iterable {self.iterable} with no initial value"


async def send_and_next_async(
value: _U, generator: AsyncGenerator[_T | None, _U | None], /
) -> _T:
"""Send a value to a generator, and then yield the output."""
result = ensure_not_none(await generator.asend(value))
_ = await anext(generator)
return result


def start_async_generator_coroutine(
func: Callable[_P, AsyncGenerator[_T, _U]], /
) -> Callable[_P, Coroutine1[AsyncGenerator[_T, _U]]]:
"""Instantiate and then start a generator-coroutine."""

@wraps(func)
async def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> AsyncGenerator[_T, _U]:
coro = func(*args, **kwargs)
_ = await anext(coro)
return coro

return wrapped


def timeout_dur(*, duration: Duration | None = None) -> Timeout:
"""Timeout context manager which accepts durations."""
delay = None if duration is None else duration_to_float(duration)
Expand Down Expand Up @@ -233,6 +264,8 @@ async def try_await(obj: MaybeAwaitable[_T], /) -> _T:
"groupby_async_list",
"is_awaitable",
"reduce_async",
"send_and_next_async",
"start_async_generator_coroutine",
"timeout_dur",
"to_list",
"to_set",
Expand Down
4 changes: 3 additions & 1 deletion src/utilities/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

from typing_extensions import override

from utilities.functions import ensure_not_none
from utilities.platform import SYSTEM, System
from utilities.types import Duration, ensure_not_none
from utilities.zoneinfo import (
UTC,
HongKong,
Expand All @@ -31,6 +31,8 @@
from collections.abc import Iterator
from zoneinfo import ZoneInfo

from utilities.types import Duration


_DAYS_PER_YEAR = 365.25
_MICROSECONDS_PER_MILLISECOND = int(1e3)
Expand Down
Loading

0 comments on commit bbb9828

Please sign in to comment.