diff --git a/python/lsst/daf/butler/remote_butler/_factory.py b/python/lsst/daf/butler/remote_butler/_factory.py
index b872d2fa86..de18a5b4e5 100644
--- a/python/lsst/daf/butler/remote_butler/_factory.py
+++ b/python/lsst/daf/butler/remote_butler/_factory.py
@@ -66,7 +66,18 @@ def __init__(self, server_url: str, http_client: httpx.Client | None = None):
         if http_client is not None:
             self.http_client = http_client
         else:
-            self.http_client = httpx.Client()
+            self.http_client = httpx.Client(
+                # This timeout is fairly conservative. This value isn't the
+                # maximum amount of time the request can take -- it's the
+                # maximum amount of time to wait after receiving the last chunk
+                # of data from the server.
+                #
+                # Long-running, streamed queries send a keep-alive every 15
+                # seconds. However, unstreamed operations like
+                # queryCollections can potentially take a while if the database
+                # is under duress.
+                timeout=120  # seconds
+            )
         self._cache = RemoteButlerCache()
 
     @staticmethod
diff --git a/python/lsst/daf/butler/remote_butler/_query_driver.py b/python/lsst/daf/butler/remote_butler/_query_driver.py
index 5cd2d0a793..8bc9097cbf 100644
--- a/python/lsst/daf/butler/remote_butler/_query_driver.py
+++ b/python/lsst/daf/butler/remote_butler/_query_driver.py
@@ -147,8 +147,11 @@ def execute(self, result_spec: ResultSpec, tree: QueryTree) -> Iterator[ResultPa
             # There is one result page JSON object per line of the
             # response.
             for line in response.iter_lines():
-                result_chunk = _QueryResultTypeAdapter.validate_json(line)
-                yield _convert_query_result_page(result_spec, result_chunk, universe)
+                result_chunk: QueryExecuteResultData = _QueryResultTypeAdapter.validate_json(line)
+                if result_chunk.type == "keep-alive":
+                    _received_keep_alive()
+                else:
+                    yield _convert_query_result_page(result_spec, result_chunk, universe)
         if self._closed:
             raise RuntimeError(
                 "Cannot continue query result iteration: query context has been closed"
             )
@@ -279,3 +282,10 @@ def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel)
         for row in model.rows
     ]
     return GeneralResultPage(spec=spec, rows=rows)
+
+
+def _received_keep_alive() -> None:
+    """Do nothing. Gives a place for unit tests to hook in for testing
+    keepalive behavior.
+    """
+    pass
diff --git a/python/lsst/daf/butler/remote_butler/server/_server.py b/python/lsst/daf/butler/remote_butler/server/_server.py
index 9c72f520d5..edf9dea405 100644
--- a/python/lsst/daf/butler/remote_butler/server/_server.py
+++ b/python/lsst/daf/butler/remote_butler/server/_server.py
@@ -33,7 +33,6 @@
 
 import safir.dependencies.logger
 from fastapi import FastAPI, Request, Response
-from fastapi.middleware.gzip import GZipMiddleware
 from fastapi.staticfiles import StaticFiles
 from safir.logging import configure_logging, configure_uvicorn_logging
 
@@ -54,7 +53,6 @@ def create_app() -> FastAPI:
 
     config = load_config()
     app = FastAPI()
-    app.add_middleware(GZipMiddleware, minimum_size=1000)
 
     # A single instance of the server can serve data from multiple Butler
     # repositories. This 'repository' path placeholder is consumed by
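A note that is not part of the patch: the single number passed to httpx.Client in _factory.py above is shorthand for httpx's per-phase timeout, which is why the comment describes it as a limit between chunks of data rather than on the whole request. A rough sketch of the equivalent explicit form, assuming current httpx semantics:

    import httpx

    # Roughly equivalent to httpx.Client(timeout=120): the value applies
    # separately to the connect, read, write, and pool-acquire phases. For a
    # streamed response, the read timeout only requires *some* data (such as a
    # keep-alive line) every 120 seconds; the request itself may run longer.
    client = httpx.Client(
        timeout=httpx.Timeout(connect=120.0, read=120.0, write=120.0, pool=120.0)
    )

Because the server side of this patch emits a keep-alive line every 15 seconds, the 120-second read timeout should only trip if the server stops responding entirely.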
diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
index 909088e198..56767f29a2 100644
--- a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
+++ b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
@@ -29,8 +29,9 @@
 
 __all__ = ("query_router",)
 
-from collections.abc import AsyncIterator, Iterable, Iterator
-from contextlib import ExitStack, contextmanager
+import asyncio
+from collections.abc import AsyncIterator, Iterator
+from contextlib import contextmanager
 from typing import NamedTuple
 
 from fastapi import APIRouter, Depends
@@ -42,72 +43,112 @@
     QueryAnyResponseModel,
     QueryCountRequestModel,
     QueryCountResponseModel,
+    QueryErrorResultModel,
     QueryExecuteRequestModel,
+    QueryExecuteResultData,
     QueryExplainRequestModel,
     QueryExplainResponseModel,
     QueryInputs,
+    QueryKeepAliveModel,
 )
-from ....queries.driver import QueryDriver, QueryTree, ResultPage, ResultSpec
+from ...._exceptions import ButlerUserError
+from ....queries.driver import QueryDriver, QueryTree, ResultSpec
+from ..._errors import serialize_butler_user_error
 from .._dependencies import factory_dependency
 from .._factory import Factory
-from ._query_serialization import serialize_query_pages
+from ._query_serialization import convert_query_page
 
 query_router = APIRouter()
 
+# Alias this function so we can mock it during unit tests.
+_timeout = asyncio.timeout
+
 
 @query_router.post("/v1/query/execute", summary="Query the Butler database and return full results")
-def query_execute(
+async def query_execute(
     request: QueryExecuteRequestModel, factory: Factory = Depends(factory_dependency)
 ) -> StreamingResponse:
-    # Managing the lifetime of the query context object is a little tricky. We
-    # need to enter the context here, so that we can immediately deal with any
-    # exceptions raised by query set-up. We eventually transfer control to an
-    # iterator consumed by FastAPI's StreamingResponse handler, which will
-    # start iterating after this function returns. So we use this ExitStack
-    # instance to hand over the context manager to the iterator.
-    with ExitStack() as exit_stack:
-        ctx = exit_stack.enter_context(_get_query_context(factory, request.query))
-        spec = request.result_spec.to_result_spec(ctx.driver.universe)
-        response_pages = ctx.driver.execute(spec, ctx.tree)
-
-        # We write the response incrementally, one page at a time, as
-        # newline-separated chunks of JSON. This allows clients to start
-        # reading results earlier and prevents the server from exhausting
-        # all its memory buffering rows from large queries.
-        output_generator = _stream_query_pages(
-            # Transfer control of the context manager to
-            # _stream_query_pages.
-            exit_stack.pop_all(),
-            spec,
-            response_pages,
-        )
-        return StreamingResponse(output_generator, media_type="application/jsonlines")
-
-    # Mypy thinks that ExitStack might swallow an exception.
-    assert False, "This line is unreachable."
-
-
-async def _stream_query_pages(
-    exit_stack: ExitStack, spec: ResultSpec, pages: Iterable[ResultPage]
-) -> AsyncIterator[str]:
-    # Instead of declaring this as a sync generator with 'def', it's async to
-    # give us more control over the lifetime of exit_stack. StreamingResponse
-    # ensures that this async generator is cancelled if the client
-    # disconnects or another error occurs, ensuring that clean-up logic runs.
-    #
-    # If it was sync, it would get wrapped in an async function internal to
-    # FastAPI that does not guarantee that the generator is fully iterated or
-    # closed.
-    # (There is an example in the FastAPI docs showing StreamingResponse with a
-    # sync generator with a context manager, but after reading the FastAPI
-    # source code I believe that for sync generators it will leak the context
-    # manager if the client disconnects, and that it would be
-    # difficult/impossible for them to fix this in the general case within
-    # FastAPI.)
-    async with contextmanager_in_threadpool(exit_stack):
-        async for chunk in iterate_in_threadpool(serialize_query_pages(spec, pages)):
-            yield chunk
+    # We write the response incrementally, one page at a time, as
+    # newline-separated chunks of JSON. This allows clients to start
+    # reading results earlier and prevents the server from exhausting
+    # all its memory buffering rows from large queries.
+    output_generator = _stream_query_pages(request, factory)
+    return StreamingResponse(
+        output_generator,
+        media_type="application/jsonlines",
+        headers={
+            # Instruct the Kubernetes ingress to not buffer the response,
+            # so that keep-alives reach the client promptly.
+            "X-Accel-Buffering": "no"
+        },
+    )
+
+
+async def _stream_query_pages(request: QueryExecuteRequestModel, factory: Factory) -> AsyncIterator[str]:
+    """Stream the query output with one page object per line, as
+    newline-delimited JSON records in the "JSON Lines" format
+    (https://jsonlines.org/).
+
+    When it takes longer than 15 seconds to get a response from the DB,
+    sends a keep-alive message to prevent clients from timing out.
+    """
+    # `None` signals that there is no more data to send.
+    queue = asyncio.Queue[QueryExecuteResultData | None](1)
+    async with asyncio.TaskGroup() as tg:
+        # Run a background task to read from the DB and insert the result pages
+        # into a queue.
+        tg.create_task(_enqueue_query_pages(queue, request, factory))
+        # Read the result pages from the queue and send them to the client,
+        # inserting a keep-alive message every 15 seconds if we are waiting a
+        # long time for the database.
+        async for message in _dequeue_query_pages_with_keepalive(queue):
+            yield message.model_dump_json() + "\n"
+
+
+async def _enqueue_query_pages(
+    queue: asyncio.Queue[QueryExecuteResultData | None], request: QueryExecuteRequestModel, factory: Factory
+) -> None:
+    """Set up a QueryDriver to run the query, and copy the results into a
+    queue. Send `None` to the queue when there is no more data to read.
+    """
+    try:
+        async with contextmanager_in_threadpool(_get_query_context(factory, request.query)) as ctx:
+            spec = request.result_spec.to_result_spec(ctx.driver.universe)
+            async for page in iterate_in_threadpool(_retrieve_query_pages(ctx, spec)):
+                await queue.put(page)
+    except ButlerUserError as e:
+        # If a user-facing error occurs, serialize it and send it to the
+        # client.
+        await queue.put(QueryErrorResultModel(error=serialize_butler_user_error(e)))
+
+    # Signal that there is no more data to read.
+    await queue.put(None)
+
+
+def _retrieve_query_pages(ctx: _QueryContext, spec: ResultSpec) -> Iterator[QueryExecuteResultData]:
+    """Execute the database query and return pages of results."""
+    pages = ctx.driver.execute(spec, ctx.tree)
+    for page in pages:
+        yield convert_query_page(spec, page)
+
+
+async def _dequeue_query_pages_with_keepalive(
+    queue: asyncio.Queue[QueryExecuteResultData | None],
+) -> AsyncIterator[QueryExecuteResultData]:
+    """Read and return messages from the given queue until the end-of-stream
+    message `None` is reached. If the producer is taking a long time, yields
+    a keep-alive message every 15 seconds while we are waiting.
+    """
+    while True:
+        try:
+            async with _timeout(15):
+                message = await queue.get()
+                if message is None:
+                    return
+                yield message
+        except TimeoutError:
+            yield QueryKeepAliveModel()
 
 
 @query_router.post(
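The queue-and-timeout arrangement above is easier to see in a stripped-down form. The sketch below is not part of the patch; the names and delays are illustrative, and it needs Python 3.11+ for asyncio.timeout and asyncio.TaskGroup. A slow producer fills a bounded queue while the consumer emits a keep-alive whenever it has waited too long for the next item:

    import asyncio


    async def produce(queue: asyncio.Queue[str | None]) -> None:
        # Stand-in for _enqueue_query_pages: slow pages, then end-of-stream.
        for i in range(3):
            await asyncio.sleep(2.5)  # pretend the database is slow
            await queue.put(f"page-{i}")
        await queue.put(None)


    async def consume(queue: asyncio.Queue[str | None]) -> None:
        # Stand-in for _dequeue_query_pages_with_keepalive, using a 1-second
        # timeout instead of the real 15 seconds.
        while True:
            try:
                async with asyncio.timeout(1):
                    message = await queue.get()
                    if message is None:
                        return
                    print(message)
            except TimeoutError:
                print("keep-alive")


    async def main() -> None:
        queue: asyncio.Queue[str | None] = asyncio.Queue(1)
        async with asyncio.TaskGroup() as tg:
            tg.create_task(produce(queue))
            await consume(queue)


    asyncio.run(main())

Running it prints a couple of "keep-alive" lines between each page, which is exactly the shape of traffic the endpoint now produces while the database is slow.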
diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py
index e2db3399d5..5cdc4cd796 100644
--- a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py
+++ b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py
@@ -27,9 +27,6 @@
 
 from __future__ import annotations
 
-from collections.abc import Iterable, Iterator
-
-from ...._exceptions import ButlerUserError
 from ....queries.driver import (
     DataCoordinateResultPage,
     DatasetRefResultPage,
@@ -38,36 +35,16 @@
     ResultPage,
     ResultSpec,
 )
-from ..._errors import serialize_butler_user_error
 from ...server_models import (
     DataCoordinateResultModel,
     DatasetRefResultModel,
     DimensionRecordsResultModel,
     GeneralResultModel,
-    QueryErrorResultModel,
     QueryExecuteResultData,
 )
 
 
-def serialize_query_pages(
-    spec: ResultSpec, pages: Iterable[ResultPage]
-) -> Iterator[str]:  # numpydoc ignore=PR01
-    """Serialize result pages to pages of result data in JSON format. The
-    output contains one page object per line, as newline-delimited JSON records
-    in the "JSON Lines" format (https://jsonlines.org/).
-    """
-    try:
-        for page in pages:
-            yield _convert_query_page(spec, page).model_dump_json()
-            yield "\n"
-    except ButlerUserError as e:
-        # If a user-facing error occurs, serialize it and send it to the
-        # client.
-        yield QueryErrorResultModel(error=serialize_butler_user_error(e)).model_dump_json()
-        yield "\n"
-
-
-def _convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResultData:
+def convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResultData:
     """Convert pages of result data from the query system to a serializable
     format.
 
@@ -75,8 +52,13 @@ def _convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResul
     ----------
     spec : `ResultSpec`
         Definition of the output format for the results.
-    pages : `ResultPage`
+    page : `ResultPage`
         Raw page of data from the query driver.
+
+    Returns
+    -------
+    model : `QueryExecuteResultData`
+        Serializable pydantic model version of the page.
     """
     match spec.result_type:
         case "dimension_record":
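An aside that is not part of the patch: convert_query_page returns one of the QueryExecuteResultData variants (the union extended in server_models.py below), and each variant is written to the stream as a single JSON line whose "type" field identifies it. The sketch below uses stand-in models to show how a pydantic discriminated union plus TypeAdapter, the same machinery the client's _QueryResultTypeAdapter uses, turns such a line back into the right model:

    from typing import Annotated, Literal, Union

    import pydantic


    class KeepAlive(pydantic.BaseModel):
        type: Literal["keep-alive"] = "keep-alive"


    class Page(pydantic.BaseModel):
        type: Literal["page"] = "page"
        rows: list[int]


    # Mirrors the shape of QueryExecuteResultData: the "type" field is the
    # discriminator that selects which model to instantiate.
    Message = Annotated[Union[KeepAlive, Page], pydantic.Field(discriminator="type")]
    adapter = pydantic.TypeAdapter(Message)

    assert isinstance(adapter.validate_json('{"type": "keep-alive"}'), KeepAlive)
    assert isinstance(adapter.validate_json('{"type": "page", "rows": [1, 2]}'), Page)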
""" match spec.result_type: case "dimension_record": diff --git a/python/lsst/daf/butler/remote_butler/server_models.py b/python/lsst/daf/butler/remote_butler/server_models.py index 876bfe9839..5ae692506f 100644 --- a/python/lsst/daf/butler/remote_butler/server_models.py +++ b/python/lsst/daf/butler/remote_butler/server_models.py @@ -299,12 +299,24 @@ class QueryErrorResultModel(pydantic.BaseModel): error: ErrorResponseModel +class QueryKeepAliveModel(pydantic.BaseModel): + """Result model for /query/execute used to keep connection alive. + + Some queries require a significant start-up time before they can start + returning results, or a long processing time for each chunk of rows. This + message signals that the server is still fetching the data. + """ + + type: Literal["keep-alive"] = "keep-alive" + + QueryExecuteResultData: TypeAlias = Annotated[ DataCoordinateResultModel | DimensionRecordsResultModel | DatasetRefResultModel | GeneralResultModel - | QueryErrorResultModel, + | QueryErrorResultModel + | QueryKeepAliveModel, pydantic.Field(discriminator="type"), ] diff --git a/tests/test_server.py b/tests/test_server.py index c6390a203c..33cae91c8c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -34,6 +34,8 @@ try: # Failing to import any of these should disable the tests. + import lsst.daf.butler.remote_butler._query_driver + import lsst.daf.butler.remote_butler.server.handlers._external_query import safir.dependencies.logger from fastapi.testclient import TestClient from lsst.daf.butler.remote_butler import RemoteButler @@ -47,7 +49,7 @@ create_test_server = None reason_text = str(e) -from unittest.mock import NonCallableMock, patch +from unittest.mock import DEFAULT, NonCallableMock, patch from lsst.daf.butler import ( Butler, @@ -402,6 +404,28 @@ async def get_logger(): self.assertEqual(kwargs["clientRequestId"], "request-id") self.assertEqual(kwargs["user"], "user-name") + def test_query_keepalive(self): + """Test that long-running queries stream keep-alive messages to stop + the HTTP connection from closing before they are able to return + results. + """ + # Normally it takes 15 seconds for a timeout -- mock it to trigger + # immediately instead. + with patch.object( + lsst.daf.butler.remote_butler.server.handlers._external_query, "_timeout" + ) as mock_timeout: + # Hook into QueryDriver to track the number of keep-alives we have + # seen. + with patch.object( + lsst.daf.butler.remote_butler._query_driver, "_received_keep_alive" + ) as mock_keep_alive: + mock_timeout.side_effect = _timeout_twice() + with self.butler._query() as query: + datasets = list(query.datasets("bias", "imported_g")) + self.assertEqual(len(datasets), 3) + self.assertGreaterEqual(mock_timeout.call_count, 3) + self.assertGreaterEqual(mock_keep_alive.call_count, 2) + def _create_corrupted_dataset(repo: MetricTestRepo) -> DatasetRef: run = "corrupted-run" @@ -418,5 +442,21 @@ def _create_simple_dataset(butler: Butler) -> DatasetRef: return ref +def _timeout_twice(): + """Return a mock side-effect function that raises a timeout error the first + two times it is called. + """ + count = 0 + + def timeout(*args): + nonlocal count + count += 1 + if count <= 2: + raise TimeoutError() + return DEFAULT + + return timeout + + if __name__ == "__main__": unittest.main()