Simplify query resource management in server
dhirving committed Aug 29, 2024
1 parent 670ba4d commit 3d29584
Showing 1 changed file with 54 additions and 80 deletions.
@@ -31,7 +31,7 @@

import asyncio
from collections.abc import AsyncIterator, Iterator
from contextlib import ExitStack, contextmanager
from contextlib import contextmanager
from typing import NamedTuple

from fastapi import APIRouter, Depends
@@ -63,87 +63,73 @@


@query_router.post("/v1/query/execute", summary="Query the Butler database and return full results")
def query_execute(
async def query_execute(
request: QueryExecuteRequestModel, factory: Factory = Depends(factory_dependency)
) -> StreamingResponse:
# Managing the lifetime of the query context object is a little tricky. We
# need to enter the context here, so that we can immediately deal with any
# exceptions raised by query set-up. We eventually transfer control to an
# iterator consumed by FastAPI's StreamingResponse handler, which will
# start iterating after this function returns. So we use this ExitStack
# instance to hand over the context manager to the iterator.
with ExitStack() as exit_stack:
ctx = exit_stack.enter_context(_get_query_context(factory, request.query))
spec = request.result_spec.to_result_spec(ctx.driver.universe)

# We write the response incrementally, one page at a time, as
# newline-separated chunks of JSON. This allows clients to start
# reading results earlier and prevents the server from exhausting
# all its memory buffering rows from large queries.
output_generator = _stream_query_pages(
# Transfer control of the context manager to
# _stream_query_pages.
exit_stack.pop_all(),
ctx,
spec,
)
return StreamingResponse(
output_generator,
media_type="application/jsonlines",
headers={
# Instruct the Kubernetes ingress to not buffer the response,
# so that keep-alives reach the client promptly.
"X-Accel-Buffering": "no"
},
)

# Mypy thinks that ExitStack might swallow an exception.
assert False, "This line is unreachable."
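
A minimal, self-contained sketch of the ExitStack.pop_all() handoff idiom the removed code relied on; the names here are illustrative, not from the Butler codebase:

```python
from collections.abc import Iterator
from contextlib import ExitStack, contextmanager


@contextmanager
def _open_resource() -> Iterator[str]:
    # Stand-in for query set-up that may raise immediately.
    print("opened")
    try:
        yield "resource"
    finally:
        print("closed")


def start() -> Iterator[str]:
    with ExitStack() as stack:
        # Set-up errors surface here, inside the request handler.
        resource = stack.enter_context(_open_resource())
        # pop_all() moves the open context manager onto a fresh stack, so
        # exiting the 'with' block above no longer closes the resource.
        return _consume(stack.pop_all(), resource)


def _consume(stack: ExitStack, resource: str) -> Iterator[str]:
    # The consumer is now responsible for clean-up.
    with stack:
        yield resource


for item in start():
    print(item)  # prints "resource"; "closed" follows when iteration ends
```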


# Instead of declaring this as a sync generator with 'def', it's async to
# give us more control over the lifetime of exit_stack. StreamingResponse
# ensures that this async generator is cancelled if the client
disconnects or another error occurs, so that clean-up logic runs.
#
# If it were sync, it would get wrapped in an async function internal to
# FastAPI that does not guarantee that the generator is fully iterated or
# closed.
# (There is an example in the FastAPI docs showing StreamingResponse with a
# sync generator with a context manager, but after reading the FastAPI
# source code I believe that for sync generators it will leak the context
# manager if the client disconnects, and that it would be
# difficult/impossible for them to fix this in the general case within
# FastAPI.)
async def _stream_query_pages(
exit_stack: ExitStack, ctx: _QueryContext, spec: ResultSpec
) -> AsyncIterator[str]:
# We write the response incrementally, one page at a time, as
# newline-separated chunks of JSON. This allows clients to start
# reading results earlier and prevents the server from exhausting
# all its memory buffering rows from large queries.
output_generator = _stream_query_pages(request, factory)
return StreamingResponse(
output_generator,
media_type="application/jsonlines",
headers={
# Instruct the Kubernetes ingress to not buffer the response,
# so that keep-alives reach the client promptly.
"X-Accel-Buffering": "no"
},
)
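
The comment above on sync vs. async generators comes down to a Python-level guarantee: closing an async generator (which, per the comment's reading of the FastAPI source, is what happens on client disconnect when the generator is async) runs its finally blocks. A minimal sketch of that guarantee, independent of the Butler code:

```python
import asyncio
from collections.abc import AsyncIterator


async def _stream() -> AsyncIterator[str]:
    try:
        while True:
            await asyncio.sleep(1)
            yield "page\n"
    finally:
        # Runs on aclose()/cancellation as well as on normal exhaustion --
        # the place where per-request clean-up can safely live.
        print("clean-up ran")


async def main() -> None:
    gen = _stream()
    print(await anext(gen))  # consume one page
    await gen.aclose()       # simulates the disconnect path; prints "clean-up ran"


asyncio.run(main())
```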


async def _stream_query_pages(request: QueryExecuteRequestModel, factory: Factory) -> AsyncIterator[str]:
"""Stream the query output with one page object per line, as
newline-delimited JSON records in the "JSON Lines" format
(https://jsonlines.org/).
When it takes longer than 15 seconds to get a response from the DB,
a keep-alive message is sent to prevent clients from timing out.
"""
# Ensure that the database connection is cleaned up by taking control of
# exit_stack.
async with contextmanager_in_threadpool(exit_stack):
# `None` signals that there is no more data to send.
queue = asyncio.Queue[QueryExecuteResultData | None](1)
async with asyncio.TaskGroup() as tg:
tg.create_task(_enqueue_query_pages(ctx, spec, queue))
async for message in _dequeue_query_pages_with_keepalive(queue):
yield message.model_dump_json() + "\n"


async def _enqueue_query_pages(ctx: _QueryContext, spec: ResultSpec, queue: asyncio.Queue) -> None:
async for page in iterate_in_threadpool(_retrieve_query_pages(ctx, spec)):
await queue.put(page)
# `None` signals that there is no more data to send.
queue = asyncio.Queue[QueryExecuteResultData | None](1)
async with asyncio.TaskGroup() as tg:
# Run a background task to read from the DB and insert the result pages
# into a queue.
tg.create_task(_enqueue_query_pages(queue, request, factory))
# Read the result pages from the queue and send them to the client,
# inserting a keep-alive message every 15 seconds if we are waiting a
# long time for the database.
async for message in _dequeue_query_pages_with_keepalive(queue):
yield message.model_dump_json() + "\n"
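
Because the payload is newline-delimited JSON, a client can act on each page as soon as its line arrives rather than waiting for the full response. A hypothetical consumer, assuming an httpx-style streaming client (not part of this commit):

```python
import json

import httpx  # assumed; any HTTP client that can stream lines works


def read_pages(url: str, query: dict) -> None:
    with httpx.Client(timeout=None) as client:
        with client.stream("POST", url, json=query) as response:
            for line in response.iter_lines():
                if not line:
                    continue
                record = json.loads(line)
                # Each record is one page of results, a keep-alive, or a
                # serialized user error; the exact shapes live in the
                # result models, not shown here.
                print(record)
```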


async def _enqueue_query_pages(
queue: asyncio.Queue[QueryExecuteResultData | None], request: QueryExecuteRequestModel, factory: Factory
) -> None:
"""Set up a QueryDriver to run the query, and copy the results into a
queue. Send `None` to the queue when there is no more data to read.
"""
try:
async with contextmanager_in_threadpool(_get_query_context(factory, request.query)) as ctx:
spec = request.result_spec.to_result_spec(ctx.driver.universe)
async for page in iterate_in_threadpool(_retrieve_query_pages(ctx, spec)):
await queue.put(page)
except ButlerUserError as e:
# If a user-facing error occurs, serialize it and send it to the
# client.
await queue.put(QueryErrorResultModel(error=serialize_butler_user_error(e)))

# Signal that there is no more data to read.
await queue.put(None)
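
contextmanager_in_threadpool and iterate_in_threadpool (both importable from fastapi.concurrency; the latter originates in starlette.concurrency) keep blocking database work off the event loop by running it on worker threads. A minimal sketch of how they combine, with stand-ins for the real query context and pages:

```python
from collections.abc import Iterator
from contextlib import contextmanager

from fastapi.concurrency import contextmanager_in_threadpool, iterate_in_threadpool


@contextmanager
def _blocking_session() -> Iterator[list[int]]:
    # Stand-in for _get_query_context: blocking set-up and tear-down.
    yield [1, 2, 3]


def _blocking_pages(session: list[int]) -> Iterator[int]:
    # Stand-in for _retrieve_query_pages: a blocking sync generator.
    yield from session


async def stream() -> None:
    # __enter__ and __exit__ run on threadpool workers, not the event loop.
    async with contextmanager_in_threadpool(_blocking_session()) as session:
        # Each step of the sync generator also runs on a worker thread.
        async for page in iterate_in_threadpool(_blocking_pages(session)):
            print(page)
```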


def _retrieve_query_pages(ctx: _QueryContext, spec: ResultSpec) -> Iterator[QueryExecuteResultData]:
"""Execute the database query and and return pages of results."""
pages = ctx.driver.execute(spec, ctx.tree)
for page in pages:
yield convert_query_page(spec, page)


async def _dequeue_query_pages_with_keepalive(
queue: asyncio.Queue[QueryExecuteResultData | None],
) -> AsyncIterator[QueryExecuteResultData]:
@@ -162,18 +148,6 @@ async def _dequeue_query_pages_with_keepalive(
yield QueryKeepAliveModel()
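
The dequeue loop is only partially visible in this hunk; it pairs each queue read with a timeout so that a keep-alive record goes out whenever the database is slow. A sketch of that shape under stated assumptions, since the real body is not fully shown in the diff:

```python
import asyncio
from collections.abc import AsyncIterator


async def _dequeue_with_keepalive(
    queue: asyncio.Queue[object | None],  # None is the end-of-stream sentinel
    timeout: float = 15.0,
) -> AsyncIterator[object]:
    while True:
        try:
            item = await asyncio.wait_for(queue.get(), timeout=timeout)
        except TimeoutError:  # asyncio.TimeoutError on Python < 3.11
            # Nothing arrived in time; emit a keep-alive instead of a page.
            yield {"keep_alive": True}  # stands in for QueryKeepAliveModel()
            continue
        if item is None:
            return
        yield item
```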


def _retrieve_query_pages(ctx: _QueryContext, spec: ResultSpec) -> Iterator[QueryExecuteResultData]:
"""Execute the database query and and return pages of results."""
try:
pages = ctx.driver.execute(spec, ctx.tree)
for page in pages:
yield convert_query_page(spec, page)
except ButlerUserError as e:
# If a user-facing error occurs, serialize it and send it to the
# client.
yield QueryErrorResultModel(error=serialize_butler_user_error(e))


@query_router.post(
"/v1/query/count",
summary="Query the Butler database and return a count of rows that would be returned.",
