Skip to content

Commit

Permalink
Simplify keep-alive logic
Browse files Browse the repository at this point in the history
  • Loading branch information
dhirving committed Aug 29, 2024
1 parent dd04dde commit 1d0c2f8
Showing 1 changed file with 31 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,38 @@ async def _stream_query_pages(
# Ensure that the database connection is cleaned up by taking control of
# exit_stack.
async with contextmanager_in_threadpool(exit_stack):
iterator = iterate_in_threadpool(_retrieve_query_pages(ctx, spec))
done = False
while not done:
# Read the next value from the iterator, possibly with some
# additional keep-alive messages if it takes a long time.
async for message in _fetch_next_with_keepalives(iterator):
# `None` signals that there is no more data to send.
queue = asyncio.Queue[QueryExecuteResultData | None](1)
async with asyncio.TaskGroup() as tg:
tg.create_task(_enqueue_query_pages(ctx, spec, queue))
async for message in _dequeue_query_pages_with_keepalive(queue):
yield message.model_dump_json() + "\n"


async def _enqueue_query_pages(
    ctx: _QueryContext, spec: ResultSpec, queue: asyncio.Queue[QueryExecuteResultData | None]
) -> None:
    """Producer task: retrieve result pages and push them onto ``queue``.

    Parameters
    ----------
    ctx
        Query context (presumably holds the database connection state —
        the connection itself is managed by the caller's exit stack).
    spec
        Specification of the result rows to retrieve.
    queue
        Bounded queue shared with the consumer.  A trailing ``None`` is
        enqueued as the end-of-stream sentinel.
    """
    # _retrieve_query_pages is a blocking (synchronous) generator, so drive
    # it on a worker thread to avoid stalling the event loop.
    async for page in iterate_in_threadpool(_retrieve_query_pages(ctx, spec)):
        await queue.put(page)

    # Signal that there is no more data to read.
    await queue.put(None)


async def _dequeue_query_pages_with_keepalive(
    queue: asyncio.Queue[QueryExecuteResultData | None],
) -> AsyncIterator[QueryExecuteResultData]:
    """Read and return messages from the given queue until the end-of-stream
    message `None` is reached. If the producer is taking a long time, returns
    a keep-alive message every 15 seconds while we are waiting.
    """
    while True:
        try:
            # Wait up to 15 seconds for the producer's next page.
            async with asyncio.timeout(15):
                message = await queue.get()
        except TimeoutError:
            # Producer is still working; emit a keep-alive so the client's
            # connection does not idle out, then go back to waiting.
            yield QueryKeepAliveModel()
            continue
        if message is None:
            # End-of-stream sentinel from the producer.
            return
        yield message


def _retrieve_query_pages(ctx: _QueryContext, spec: ResultSpec) -> Iterator[QueryExecuteResultData]:
Expand All @@ -153,34 +174,6 @@ def _retrieve_query_pages(ctx: _QueryContext, spec: ResultSpec) -> Iterator[Quer
yield QueryErrorResultModel(error=serialize_butler_user_error(e))


async def _fetch_next_with_keepalives(
    iterator: AsyncIterator[QueryExecuteResultData],
) -> AsyncIterator[QueryExecuteResultData | None]:
    """Read the next value from the given iterator and yield it. Yields a
    keep-alive message every 15 seconds while waiting for the iterator to
    return a value. Yields `None` if there is nothing left to read from the
    iterator.
    """
    try:
        # anext(..., None) maps iterator exhaustion to `None` instead of
        # raising StopAsyncIteration.
        future = asyncio.ensure_future(anext(iterator, None))
        ready = False
        while not ready:
            # asyncio.wait returns (done, pending) sets; with a single
            # future, exactly one of the two sets is non-empty.
            (finished_task, pending_task) = await asyncio.wait([future], timeout=15)
            if pending_task:
                # Hit the timeout, send a keep-alive and keep waiting.
                yield QueryKeepAliveModel()
            else:
                # The next value from the iterator is ready to read.
                ready = True
    finally:
        # Even if we get cancelled above, we need to wait for this iteration to
        # complete so we don't have a dangling thread using a database
        # connection that the caller is about to clean up.
        result = await future

    yield result


@query_router.post(
"/v1/query/count",
summary="Query the Butler database and return a count of rows that would be returned.",
Expand Down

0 comments on commit 1d0c2f8

Please sign in to comment.