Skip to content

Commit

Permalink
Add keep-alive messages to long-running server queries
Browse files Browse the repository at this point in the history
Fetching database results for Butler queries can take an unpredictable amount of time, in some cases several minutes.  To ensure that clients do not time out while waiting, send a keep-alive message to the stream every fifteen seconds.
  • Loading branch information
dhirving committed Aug 29, 2024
1 parent 595115b commit 311fece
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 51 deletions.
5 changes: 3 additions & 2 deletions python/lsst/daf/butler/remote_butler/_query_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ def execute(self, result_spec: ResultSpec, tree: QueryTree) -> Iterator[ResultPa
# There is one result page JSON object per line of the
# response.
for line in response.iter_lines():
result_chunk = _QueryResultTypeAdapter.validate_json(line)
yield _convert_query_result_page(result_spec, result_chunk, universe)
result_chunk: QueryExecuteResultData = _QueryResultTypeAdapter.validate_json(line)
if result_chunk.type != "keep-alive":
yield _convert_query_result_page(result_spec, result_chunk, universe)
if self._closed:
raise RuntimeError(
"Cannot continue query result iteration: query context has been closed"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@

__all__ = ("query_router",)

from collections.abc import AsyncIterator, Iterable, Iterator
import asyncio
from collections.abc import AsyncIterator, Iterator
from contextlib import ExitStack, contextmanager
from typing import NamedTuple

Expand All @@ -42,16 +43,21 @@
QueryAnyResponseModel,
QueryCountRequestModel,
QueryCountResponseModel,
QueryErrorResultModel,
QueryExecuteRequestModel,
QueryExecuteResultData,
QueryExplainRequestModel,
QueryExplainResponseModel,
QueryInputs,
QueryKeepAliveModel,
)

from ....queries.driver import QueryDriver, QueryTree, ResultPage, ResultSpec
from ...._exceptions import ButlerUserError
from ....queries.driver import QueryDriver, QueryTree, ResultSpec
from ..._errors import serialize_butler_user_error
from .._dependencies import factory_dependency
from .._factory import Factory
from ._query_serialization import serialize_query_pages
from ._query_serialization import convert_query_page

query_router = APIRouter()

Expand All @@ -69,7 +75,6 @@ def query_execute(
with ExitStack() as exit_stack:
ctx = exit_stack.enter_context(_get_query_context(factory, request.query))
spec = request.result_spec.to_result_spec(ctx.driver.universe)
response_pages = ctx.driver.execute(spec, ctx.tree)

# We write the response incrementally, one page at a time, as
# newline-separated chunks of JSON. This allows clients to start
Expand All @@ -79,35 +84,94 @@ def query_execute(
# Transfer control of the context manager to
# _stream_query_pages.
exit_stack.pop_all(),
ctx,
spec,
response_pages,
)
return StreamingResponse(output_generator, media_type="application/jsonlines")
return StreamingResponse(
output_generator,
media_type="application/jsonlines",
headers={
# Instruct the Kubernetes ingress to not buffer the response,
# so that keep-alives reach the client promptly.
"X-Accel-Buffering": "no"
},
)

# Mypy thinks that ExitStack might swallow an exception.
assert False, "This line is unreachable."


# Instead of declaring this as a sync generator with 'def', it's async to
# give us more control over the lifetime of exit_stack. StreamingResponse
# ensures that this async generator is cancelled if the client
# disconnects or another error occurs, ensuring that clean-up logic runs.
#
# If it was sync, it would get wrapped in an async function internal to
# FastAPI that does not guarantee that the generator is fully iterated or
# closed.
# (There is an example in the FastAPI docs showing StreamingResponse with a
# sync generator with a context manager, but after reading the FastAPI
# source code I believe that for sync generators it will leak the context
# manager if the client disconnects, and that it would be
# difficult/impossible for them to fix this in the general case within
# FastAPI.)
async def _stream_query_pages(
exit_stack: ExitStack, spec: ResultSpec, pages: Iterable[ResultPage]
exit_stack: ExitStack, ctx: _QueryContext, spec: ResultSpec
) -> AsyncIterator[str]:
# Instead of declaring this as a sync generator with 'def', it's async to
# give us more control over the lifetime of exit_stack. StreamingResponse
# ensures that this async generator is cancelled if the client
# disconnects or another error occurs, ensuring that clean-up logic runs.
#
# If it was sync, it would get wrapped in an async function internal to
# FastAPI that does not guarantee that the generator is fully iterated or
# closed.
# (There is an example in the FastAPI docs showing StreamingResponse with a
# sync generator with a context manager, but after reading the FastAPI
# source code I believe that for sync generators it will leak the context
# manager if the client disconnects, and that it would be
# difficult/impossible for them to fix this in the general case within
# FastAPI.)
"""Stream the query output with one page object per line, as
newline-delimited JSON records in the "JSON Lines" format
(https://jsonlines.org/).
When it takes longer than 15 seconds to get a response from the DB,
sends a keep-alive message to prevent clients from timing out.
"""
# Ensure that the database connection is cleaned up by taking control of
# exit_stack.
async with contextmanager_in_threadpool(exit_stack):
async for chunk in iterate_in_threadpool(serialize_query_pages(spec, pages)):
yield chunk
# `None` signals that there is no more data to send.
queue = asyncio.Queue[QueryExecuteResultData | None](1)
async with asyncio.TaskGroup() as tg:
tg.create_task(_enqueue_query_pages(ctx, spec, queue))
async for message in _dequeue_query_pages_with_keepalive(queue):
yield message.model_dump_json() + "\n"


async def _enqueue_query_pages(ctx: _QueryContext, spec: ResultSpec, queue: asyncio.Queue) -> None:
async for page in iterate_in_threadpool(_retrieve_query_pages(ctx, spec)):
await queue.put(page)

# Signal that there is no more data to read.
await queue.put(None)


async def _dequeue_query_pages_with_keepalive(
queue: asyncio.Queue[QueryExecuteResultData | None],
) -> AsyncIterator[QueryExecuteResultData]:
"""Read and return messages from the given queue until the end-of-stream
message `None` is reached. If the producer is taking a long time, returns
a keep-alive message every 15 seconds while we are waiting.
"""
while True:
try:
async with asyncio.timeout(15):
message = await queue.get()
if message is None:
return
yield message
except TimeoutError:
yield QueryKeepAliveModel()


def _retrieve_query_pages(ctx: _QueryContext, spec: ResultSpec) -> Iterator[QueryExecuteResultData]:
"""Execute the database query and and return pages of results."""
try:
pages = ctx.driver.execute(spec, ctx.tree)
for page in pages:
yield convert_query_page(spec, page)
except ButlerUserError as e:
# If a user-facing error occurs, serialize it and send it to the
# client.
yield QueryErrorResultModel(error=serialize_butler_user_error(e))


@query_router.post(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@

from __future__ import annotations

from collections.abc import Iterable, Iterator

from ...._exceptions import ButlerUserError
from ....queries.driver import (
DataCoordinateResultPage,
DatasetRefResultPage,
Expand All @@ -38,45 +35,30 @@
ResultPage,
ResultSpec,
)
from ..._errors import serialize_butler_user_error
from ...server_models import (
DataCoordinateResultModel,
DatasetRefResultModel,
DimensionRecordsResultModel,
GeneralResultModel,
QueryErrorResultModel,
QueryExecuteResultData,
)


def serialize_query_pages(
spec: ResultSpec, pages: Iterable[ResultPage]
) -> Iterator[str]: # numpydoc ignore=PR01
"""Serialize result pages to pages of result data in JSON format. The
output contains one page object per line, as newline-delimited JSON records
in the "JSON Lines" format (https://jsonlines.org/).
"""
try:
for page in pages:
yield _convert_query_page(spec, page).model_dump_json()
yield "\n"
except ButlerUserError as e:
# If a user-facing error occurs, serialize it and send it to the
# client.
yield QueryErrorResultModel(error=serialize_butler_user_error(e)).model_dump_json()
yield "\n"


def _convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResultData:
def convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResultData:
"""Convert pages of result data from the query system to a serializable
format.
Parameters
----------
spec : `ResultSpec`
Definition of the output format for the results.
pages : `ResultPage`
page : `ResultPage`
Raw page of data from the query driver.
Returns
-------
model : `QueryExecuteResultData`
Serializable pydantic model version of the page.
"""
match spec.result_type:
case "dimension_record":
Expand Down
14 changes: 13 additions & 1 deletion python/lsst/daf/butler/remote_butler/server_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,24 @@ class QueryErrorResultModel(pydantic.BaseModel):
error: ErrorResponseModel


class QueryKeepAliveModel(pydantic.BaseModel):
"""Result model for /query/execute used to keep connection alive.
Some queries require a significant start-up time before they can start
returning results, or a long processing time for each chunk of rows. This
message signals that the server is still fetching the data.
"""

type: Literal["keep-alive"] = "keep-alive"


QueryExecuteResultData: TypeAlias = Annotated[
DataCoordinateResultModel
| DimensionRecordsResultModel
| DatasetRefResultModel
| GeneralResultModel
| QueryErrorResultModel,
| QueryErrorResultModel
| QueryKeepAliveModel,
pydantic.Field(discriminator="type"),
]

Expand Down

0 comments on commit 311fece

Please sign in to comment.