diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py index 3b5f627b61..8132702b44 100644 --- a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py +++ b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py @@ -31,7 +31,7 @@ import asyncio from collections.abc import AsyncIterator, Iterator -from contextlib import ExitStack, contextmanager +from contextlib import contextmanager from typing import NamedTuple from fastapi import APIRouter, Depends @@ -63,61 +63,26 @@ @query_router.post("/v1/query/execute", summary="Query the Butler database and return full results") -def query_execute( +async def query_execute( request: QueryExecuteRequestModel, factory: Factory = Depends(factory_dependency) ) -> StreamingResponse: - # Managing the lifetime of the query context object is a little tricky. We - # need to enter the context here, so that we can immediately deal with any - # exceptions raised by query set-up. We eventually transfer control to an - # iterator consumed by FastAPI's StreamingResponse handler, which will - # start iterating after this function returns. So we use this ExitStack - # instance to hand over the context manager to the iterator. - with ExitStack() as exit_stack: - ctx = exit_stack.enter_context(_get_query_context(factory, request.query)) - spec = request.result_spec.to_result_spec(ctx.driver.universe) - - # We write the response incrementally, one page at a time, as - # newline-separated chunks of JSON. This allows clients to start - # reading results earlier and prevents the server from exhausting - # all its memory buffering rows from large queries. - output_generator = _stream_query_pages( - # Transfer control of the context manager to - # _stream_query_pages. - exit_stack.pop_all(), - ctx, - spec, - ) - return StreamingResponse( - output_generator, - media_type="application/jsonlines", - headers={ - # Instruct the Kubernetes ingress to not buffer the response, - # so that keep-alives reach the client promptly. - "X-Accel-Buffering": "no" - }, - ) - - # Mypy thinks that ExitStack might swallow an exception. - assert False, "This line is unreachable." - - -# Instead of declaring this as a sync generator with 'def', it's async to -# give us more control over the lifetime of exit_stack. StreamingResponse -# ensures that this async generator is cancelled if the client -# disconnects or another error occurs, ensuring that clean-up logic runs. -# -# If it was sync, it would get wrapped in an async function internal to -# FastAPI that does not guarantee that the generator is fully iterated or -# closed. -# (There is an example in the FastAPI docs showing StreamingResponse with a -# sync generator with a context manager, but after reading the FastAPI -# source code I believe that for sync generators it will leak the context -# manager if the client disconnects, and that it would be -# difficult/impossible for them to fix this in the general case within -# FastAPI.) -async def _stream_query_pages( - exit_stack: ExitStack, ctx: _QueryContext, spec: ResultSpec -) -> AsyncIterator[str]: + # We write the response incrementally, one page at a time, as + # newline-separated chunks of JSON. This allows clients to start + # reading results earlier and prevents the server from exhausting + # all its memory buffering rows from large queries. + output_generator = _stream_query_pages(request, factory) + return StreamingResponse( + output_generator, + media_type="application/jsonlines", + headers={ + # Instruct the Kubernetes ingress to not buffer the response, + # so that keep-alives reach the client promptly. + "X-Accel-Buffering": "no" + }, + ) + + +async def _stream_query_pages(request: QueryExecuteRequestModel, factory: Factory) -> AsyncIterator[str]: """Stream the query output with one page object per line, as newline-delimited JSON records in the "JSON Lines" format (https://jsonlines.org/). @@ -125,25 +90,46 @@ async def _stream_query_pages( When it takes longer than 15 seconds to get a response from the DB, sends a keep-alive message to prevent clients from timing out. """ - # Ensure that the database connection is cleaned up by taking control of - # exit_stack. - async with contextmanager_in_threadpool(exit_stack): - # `None` signals that there is no more data to send. - queue = asyncio.Queue[QueryExecuteResultData | None](1) - async with asyncio.TaskGroup() as tg: - tg.create_task(_enqueue_query_pages(ctx, spec, queue)) - async for message in _dequeue_query_pages_with_keepalive(queue): - yield message.model_dump_json() + "\n" - - -async def _enqueue_query_pages(ctx: _QueryContext, spec: ResultSpec, queue: asyncio.Queue) -> None: - async for page in iterate_in_threadpool(_retrieve_query_pages(ctx, spec)): - await queue.put(page) + # `None` signals that there is no more data to send. + queue = asyncio.Queue[QueryExecuteResultData | None](1) + async with asyncio.TaskGroup() as tg: + # Run a background task to read from the DB and insert the result pages + # into a queue. + tg.create_task(_enqueue_query_pages(queue, request, factory)) + # Read the result pages from the queue and send them to the client, + # inserting a keep-alive message every 15 seconds if we are waiting a + # long time for the database. + async for message in _dequeue_query_pages_with_keepalive(queue): + yield message.model_dump_json() + "\n" + + +async def _enqueue_query_pages( + queue: asyncio.Queue[QueryExecuteResultData | None], request: QueryExecuteRequestModel, factory: Factory +) -> None: + """Set up a QueryDriver to run the query, and copy the results into a + queue. Send `None` to the queue when there is no more data to read. + """ + try: + async with contextmanager_in_threadpool(_get_query_context(factory, request.query)) as ctx: + spec = request.result_spec.to_result_spec(ctx.driver.universe) + async for page in iterate_in_threadpool(_retrieve_query_pages(ctx, spec)): + await queue.put(page) + except ButlerUserError as e: + # If a user-facing error occurs, serialize it and send it to the + # client. + await queue.put(QueryErrorResultModel(error=serialize_butler_user_error(e))) # Signal that there is no more data to read. await queue.put(None) +def _retrieve_query_pages(ctx: _QueryContext, spec: ResultSpec) -> Iterator[QueryExecuteResultData]: + """Execute the database query and and return pages of results.""" + pages = ctx.driver.execute(spec, ctx.tree) + for page in pages: + yield convert_query_page(spec, page) + + async def _dequeue_query_pages_with_keepalive( queue: asyncio.Queue[QueryExecuteResultData | None], ) -> AsyncIterator[QueryExecuteResultData]: @@ -162,18 +148,6 @@ async def _dequeue_query_pages_with_keepalive( yield QueryKeepAliveModel() -def _retrieve_query_pages(ctx: _QueryContext, spec: ResultSpec) -> Iterator[QueryExecuteResultData]: - """Execute the database query and and return pages of results.""" - try: - pages = ctx.driver.execute(spec, ctx.tree) - for page in pages: - yield convert_query_page(spec, page) - except ButlerUserError as e: - # If a user-facing error occurs, serialize it and send it to the - # client. - yield QueryErrorResultModel(error=serialize_butler_user_error(e)) - - @query_router.post( "/v1/query/count", summary="Query the Butler database and return a count of rows that would be returned.",