Skip to content

Commit

Permalink
Add support for lazy async generators (#7)
Browse files Browse the repository at this point in the history
* Add support for lazy async generators

* Assorted lint errors

* Format tests.

* Don't hold the lock during generator resolution

* Black format

* Fix default on __anext__

* Skip test if no Barrier.
  • Loading branch information
mattalbr authored Oct 2, 2023
1 parent ffc5376 commit 873a571
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 11 deletions.
77 changes: 73 additions & 4 deletions once/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,16 @@ def __init__(self, func: collections.abc.Callable):
self.func = self._inspect_function(func)
self.called = False
self.return_value: typing.Any = None

self.is_asyncgen = inspect.isasyncgenfunction(self.func)
if self.is_asyncgen:
raise SyntaxError("async generators are not (yet) supported")
self.asyncgen_finished = False
self.asyncgen_generator = None
self.asyncgen_results: list = []
self.async_generating = False

# Only works for one way generators, not anything that requires send for now.
# Async generators do support send.
self.is_syncgen = inspect.isgeneratorfunction(self.func)
# The function inspect.isawaitable is a bit of a misnomer - it refers
# to the awaitable result of an async function, not the async function
Expand All @@ -57,15 +64,75 @@ def _sync_return(self):
self.return_value.__iter__()
return self.return_value

async def _async_gen_proxy(self, func, *args, **kwargs):
i = 0
send = None
next_val = None

# A copy of self.async_generating that we can access outside of the lock.
async_generating = None

# Indicates that we're tied for the head generator, but someone started generating the next
# result first, so we should just poll until the result is available.
waiting_for_async_generating = False

while True:
if waiting_for_async_generating:
# This is a load bearing sleep. We're waiting for the leader to generate the result, but
# we have control of the lock, so the async with will never yield execution to the event loop,
# so we would loop forever. By awaiting sleep(0), we yield execution which will allow us to
# poll for self.async_generating readiness.
await asyncio.sleep(0)
waiting_for_async_generating = False
async with self.async_lock:
if self.asyncgen_generator is None and not self.asyncgen_finished:
# We're the first! Do some setup.
self.asyncgen_generator = func(*args, **kwargs)

if i == len(self.asyncgen_results) and not self.asyncgen_finished:
if self.async_generating:
# We're at the lead, but someone else is generating the next value
# so we just hop back onto the next iteration of the loop
# until it's ready.
waiting_for_async_generating = True
continue
# We're at the lead and no one else is generating, so we need to increment
# the iterator. We just store the value in self.asyncgen_results so that
# we can later yield it outside of the lock.
self.async_generating = self.asyncgen_generator.asend(send)
async_generating = self.async_generating
elif i == len(self.asyncgen_results) and self.asyncgen_finished:
# All done.
return
else:
# We already have the correct result, so we grab it here to
# yield it outside the lock.
next_val = self.asyncgen_results[i]

if async_generating:
try:
next_val = await async_generating
except StopAsyncIteration:
async with self.async_lock:
self.asyncgen_generator = None # Allow this to be GCed.
self.asyncgen_finished = True
self.async_generating = None
async_generating = None
return
async with self.async_lock:
self.asyncgen_results.append(next_val)
async_generating = None
self.async_generating = None

send = yield next_val
i += 1

async def _execute_call_once_async(self, func: collections.abc.Callable, *args, **kwargs):
if self.called:
return self.return_value
async with self.async_lock:
if self.called:
return self.return_value
# Currently unreachable code - Async iterators are disabled for now.
if self.is_asyncgen:
self.return_value = [i async for i in func(*args, **kwargs)]
else:
self.return_value = await func(*args, **kwargs)
self.called = True
Expand All @@ -86,6 +153,8 @@ def _execute_call_once_sync(self, func: collections.abc.Callable, *args, **kwarg
return self._sync_return()

def _execute_call_once(self, func: collections.abc.Callable, *args, **kwargs):
if self.is_asyncgen:
return self._async_gen_proxy(func, *args, **kwargs)
if self.is_async:
return self._execute_call_once_async(func, *args, **kwargs)
return self._execute_call_once_sync(func, *args, **kwargs)
Expand Down
143 changes: 136 additions & 7 deletions once_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Unit tests for once decorators."""
# pylint: disable=missing-function-docstring
import asyncio
import concurrent.futures
import inspect
import sys
import time
import unittest
from unittest import mock
Expand All @@ -12,6 +14,18 @@
import once


if sys.version_info.minor < 10:
print(f"Redefining anext for python 3.{sys.version_info.minor}")

async def anext(iter, default=StopAsyncIteration):
if default != StopAsyncIteration:
try:
return await iter.__anext__()
except StopAsyncIteration:
return default
return await iter.__anext__()


class Counter:
"""Holding object for a counter.
Expand Down Expand Up @@ -368,19 +382,134 @@ async def counting_fn2():
self.assertEqual(await counting_fn2(), 2)
self.assertEqual(await counting_fn2(), 2)

async def test_inspect_func(self):
@once.once
async def async_func():
return True

# Unfortunately these are corrupted by our @once.once.
# self.assertFalse(inspect.isasyncgenfunction(async_func))
# self.assertTrue(inspect.iscoroutinefunction(async_func))

coroutine = async_func()
self.assertTrue(inspect.iscoroutine(coroutine))
self.assertTrue(inspect.isawaitable(coroutine))
self.assertFalse(inspect.isasyncgen(coroutine))

# Just for cleanup.
await coroutine

async def test_inspect_iterator(self):
@once.once
async def async_yielding_iterator():
for i in range(3):
yield i

# Unfortunately these are corrupted by our @once.once.
# self.assertTrue(inspect.isasyncgenfunction(async_yielding_iterator))
# self.assertTrue(inspect.iscoroutinefunction(async_yielding_iterator))

coroutine = async_yielding_iterator()
self.assertFalse(inspect.iscoroutine(coroutine))
self.assertFalse(inspect.isawaitable(coroutine))
self.assertTrue(inspect.isasyncgen(coroutine))

# Just for cleanup.
async for i in coroutine:
pass

async def test_iterator(self):
counter = Counter()

with self.assertRaises(SyntaxError):
@once.once
async def async_yielding_iterator():
for i in range(3):
yield counter.get_incremented()

@once.once
async def async_yielding_iterator():
self.assertEqual([i async for i in async_yielding_iterator()], [1, 2, 3])
self.assertEqual([i async for i in async_yielding_iterator()], [1, 2, 3])

async def test_iterator_is_lazily_evaluted(self):
counter = Counter()

@once.once
async def async_yielding_iterator():
for i in range(3):
yield counter.get_incremented()
for i in range(3):
yield i

# self.assertEqual([i async for i in async_yielding_iterator()], [1, 0, 1, 2])
# self.assertEqual([i async for i in async_yielding_iterator()], [1, 0, 1, 2])
gen_1 = async_yielding_iterator()
gen_2 = async_yielding_iterator()
gen_3 = async_yielding_iterator()

self.assertEqual(counter.value, 0)
self.assertEqual(await anext(gen_1), 1)
self.assertEqual(await anext(gen_2), 1)
self.assertEqual(await anext(gen_2), 2)
self.assertEqual(await anext(gen_2), 3)
self.assertEqual(await anext(gen_1), 2)
self.assertEqual(await anext(gen_3), 1)
self.assertEqual(await anext(gen_3), 2)
self.assertEqual(await anext(gen_3), 3)
self.assertEqual(await anext(gen_3, None), None)
self.assertEqual(await anext(gen_2, None), None)
self.assertEqual(await anext(gen_1), 3)
self.assertEqual(await anext(gen_2, None), None)

async def test_receiving_iterator(self):
@once.once
async def async_receiving_iterator():
next = yield 1
while next is not None:
next = yield next

gen_1 = async_receiving_iterator()
gen_2 = async_receiving_iterator()
self.assertEqual(await gen_1.asend(None), 1)
self.assertEqual(await gen_1.asend(1), 1)
self.assertEqual(await gen_1.asend(3), 3)
self.assertEqual(await gen_2.asend(None), 1)
self.assertEqual(await gen_2.asend(None), 1)
self.assertEqual(await gen_2.asend(None), 3)
self.assertEqual(await gen_2.asend(5), 5)
self.assertEqual(await anext(gen_2, None), None)
self.assertEqual(await gen_1.asend(None), 5)
self.assertEqual(await anext(gen_1, None), None)

@unittest.skipIf(not hasattr(asyncio, "Barrier"), "Requires Barrier to evaluate")
async def test_iterator_lock_not_held_during_evaluation(self):
counter = Counter()

@once.once
async def async_yielding_iterator():
barrier = yield counter.get_incremented()
while barrier is not None:
await barrier.wait()
barrier = yield counter.get_incremented()

gen_1 = async_yielding_iterator()
gen_2 = async_yielding_iterator()
barrier = asyncio.Barrier(2)
self.assertEqual(await gen_1.asend(None), 1)
task1 = asyncio.create_task(gen_1.asend(barrier))

# Loop until task1 is stuck waiting.
while barrier.n_waiting < 1:
await asyncio.sleep(0)

self.assertEqual(
await gen_2.asend(None), 1
) # Should return immediately even though task1 is stuck.

# .asend("None") should be ignored because task1 has already started,
# so task2 should still return 2 instead of ending iteration.
task2 = asyncio.create_task(gen_2.asend(None))

await barrier.wait()

self.assertEqual(await task1, 2)
self.assertEqual(await task2, 2)
self.assertEqual(await anext(gen_1, None), None)
self.assertEqual(await anext(gen_2, None), None)

async def test_once_per_class(self):
class _CallOnceClass(Counter):
Expand Down

0 comments on commit 873a571

Please sign in to comment.