From 595115b1f64fe7591fcc9af90825c4682825554d Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Wed, 28 Aug 2024 17:27:12 -0700 Subject: [PATCH 1/5] Increase timeout for RemoteButler HTTP requests The default HTTPX timeout of 5 seconds is much too low when communicating with Butler server -- many requests can exceed that while waiting for the database to respond. --- python/lsst/daf/butler/remote_butler/_factory.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/python/lsst/daf/butler/remote_butler/_factory.py b/python/lsst/daf/butler/remote_butler/_factory.py index b872d2fa86..de18a5b4e5 100644 --- a/python/lsst/daf/butler/remote_butler/_factory.py +++ b/python/lsst/daf/butler/remote_butler/_factory.py @@ -66,7 +66,18 @@ def __init__(self, server_url: str, http_client: httpx.Client | None = None): if http_client is not None: self.http_client = http_client else: - self.http_client = httpx.Client() + self.http_client = httpx.Client( + # This timeout is fairly conservative. This value isn't the + # maximum amount of time the request can take -- it's the + # maximum amount of time to wait after receiving the last chunk + # of data from the server. + # + # Long-running, streamed queries send a keep-alive every 15 + # seconds. However, unstreamed operations like + # queryCollections can potentially take a while if the database + # is under duress. + timeout=120 # seconds + ) self._cache = RemoteButlerCache() @staticmethod From 311feceaf4bf142177d9439ddca1ff084c5ea871 Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Wed, 28 Aug 2024 17:25:36 -0700 Subject: [PATCH 2/5] Add keep-alive messages to long-running server queries Fetching database results for Butler queries can take an unpredictable amount of time, in some cases several minutes. To ensure that clients do not time out while waiting, send a keep-alive message to the stream every fifteen seconds. --- .../daf/butler/remote_butler/_query_driver.py | 5 +- .../server/handlers/_external_query.py | 110 ++++++++++++++---- .../server/handlers/_query_serialization.py | 32 ++--- .../daf/butler/remote_butler/server_models.py | 14 ++- 4 files changed, 110 insertions(+), 51 deletions(-) diff --git a/python/lsst/daf/butler/remote_butler/_query_driver.py b/python/lsst/daf/butler/remote_butler/_query_driver.py index 5cd2d0a793..d290e62ffe 100644 --- a/python/lsst/daf/butler/remote_butler/_query_driver.py +++ b/python/lsst/daf/butler/remote_butler/_query_driver.py @@ -147,8 +147,9 @@ def execute(self, result_spec: ResultSpec, tree: QueryTree) -> Iterator[ResultPa # There is one result page JSON object per line of the # response. for line in response.iter_lines(): - result_chunk = _QueryResultTypeAdapter.validate_json(line) - yield _convert_query_result_page(result_spec, result_chunk, universe) + result_chunk: QueryExecuteResultData = _QueryResultTypeAdapter.validate_json(line) + if result_chunk.type != "keep-alive": + yield _convert_query_result_page(result_spec, result_chunk, universe) if self._closed: raise RuntimeError( "Cannot continue query result iteration: query context has been closed" diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py index 909088e198..3b5f627b61 100644 --- a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py +++ b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py @@ -29,7 +29,8 @@ __all__ = ("query_router",) -from collections.abc import AsyncIterator, Iterable, Iterator +import asyncio +from collections.abc import AsyncIterator, Iterator from contextlib import ExitStack, contextmanager from typing import NamedTuple @@ -42,16 +43,21 @@ QueryAnyResponseModel, QueryCountRequestModel, QueryCountResponseModel, + QueryErrorResultModel, QueryExecuteRequestModel, + QueryExecuteResultData, QueryExplainRequestModel, QueryExplainResponseModel, QueryInputs, + QueryKeepAliveModel, ) -from ....queries.driver import QueryDriver, QueryTree, ResultPage, ResultSpec +from ...._exceptions import ButlerUserError +from ....queries.driver import QueryDriver, QueryTree, ResultSpec +from ..._errors import serialize_butler_user_error from .._dependencies import factory_dependency from .._factory import Factory -from ._query_serialization import serialize_query_pages +from ._query_serialization import convert_query_page query_router = APIRouter() @@ -69,7 +75,6 @@ def query_execute( with ExitStack() as exit_stack: ctx = exit_stack.enter_context(_get_query_context(factory, request.query)) spec = request.result_spec.to_result_spec(ctx.driver.universe) - response_pages = ctx.driver.execute(spec, ctx.tree) # We write the response incrementally, one page at a time, as # newline-separated chunks of JSON. This allows clients to start @@ -79,35 +84,94 @@ def query_execute( # Transfer control of the context manager to # _stream_query_pages. exit_stack.pop_all(), + ctx, spec, - response_pages, ) - return StreamingResponse(output_generator, media_type="application/jsonlines") + return StreamingResponse( + output_generator, + media_type="application/jsonlines", + headers={ + # Instruct the Kubernetes ingress to not buffer the response, + # so that keep-alives reach the client promptly. + "X-Accel-Buffering": "no" + }, + ) # Mypy thinks that ExitStack might swallow an exception. assert False, "This line is unreachable." +# Instead of declaring this as a sync generator with 'def', it's async to +# give us more control over the lifetime of exit_stack. StreamingResponse +# ensures that this async generator is cancelled if the client +# disconnects or another error occurs, ensuring that clean-up logic runs. +# +# If it was sync, it would get wrapped in an async function internal to +# FastAPI that does not guarantee that the generator is fully iterated or +# closed. +# (There is an example in the FastAPI docs showing StreamingResponse with a +# sync generator with a context manager, but after reading the FastAPI +# source code I believe that for sync generators it will leak the context +# manager if the client disconnects, and that it would be +# difficult/impossible for them to fix this in the general case within +# FastAPI.) async def _stream_query_pages( - exit_stack: ExitStack, spec: ResultSpec, pages: Iterable[ResultPage] + exit_stack: ExitStack, ctx: _QueryContext, spec: ResultSpec ) -> AsyncIterator[str]: - # Instead of declaring this as a sync generator with 'def', it's async to - # give us more control over the lifetime of exit_stack. StreamingResponse - # ensures that this async generator is cancelled if the client - # disconnects or another error occurs, ensuring that clean-up logic runs. - # - # If it was sync, it would get wrapped in an async function internal to - # FastAPI that does not guarantee that the generator is fully iterated or - # closed. - # (There is an example in the FastAPI docs showing StreamingResponse with a - # sync generator with a context manager, but after reading the FastAPI - # source code I believe that for sync generators it will leak the context - # manager if the client disconnects, and that it would be - # difficult/impossible for them to fix this in the general case within - # FastAPI.) + """Stream the query output with one page object per line, as + newline-delimited JSON records in the "JSON Lines" format + (https://jsonlines.org/). + + When it takes longer than 15 seconds to get a response from the DB, + sends a keep-alive message to prevent clients from timing out. + """ + # Ensure that the database connection is cleaned up by taking control of + # exit_stack. async with contextmanager_in_threadpool(exit_stack): - async for chunk in iterate_in_threadpool(serialize_query_pages(spec, pages)): - yield chunk + # `None` signals that there is no more data to send. + queue = asyncio.Queue[QueryExecuteResultData | None](1) + async with asyncio.TaskGroup() as tg: + tg.create_task(_enqueue_query_pages(ctx, spec, queue)) + async for message in _dequeue_query_pages_with_keepalive(queue): + yield message.model_dump_json() + "\n" + + +async def _enqueue_query_pages(ctx: _QueryContext, spec: ResultSpec, queue: asyncio.Queue) -> None: + async for page in iterate_in_threadpool(_retrieve_query_pages(ctx, spec)): + await queue.put(page) + + # Signal that there is no more data to read. + await queue.put(None) + + +async def _dequeue_query_pages_with_keepalive( + queue: asyncio.Queue[QueryExecuteResultData | None], +) -> AsyncIterator[QueryExecuteResultData]: + """Read and return messages from the given queue until the end-of-stream + message `None` is reached. If the producer is taking a long time, returns + a keep-alive message every 15 seconds while we are waiting. + """ + while True: + try: + async with asyncio.timeout(15): + message = await queue.get() + if message is None: + return + yield message + except TimeoutError: + yield QueryKeepAliveModel() + + +def _retrieve_query_pages(ctx: _QueryContext, spec: ResultSpec) -> Iterator[QueryExecuteResultData]: + """Execute the database query and and return pages of results.""" + try: + pages = ctx.driver.execute(spec, ctx.tree) + for page in pages: + yield convert_query_page(spec, page) + except ButlerUserError as e: + # If a user-facing error occurs, serialize it and send it to the + # client. + yield QueryErrorResultModel(error=serialize_butler_user_error(e)) @query_router.post( diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py index e2db3399d5..5cdc4cd796 100644 --- a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py +++ b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py @@ -27,9 +27,6 @@ from __future__ import annotations -from collections.abc import Iterable, Iterator - -from ...._exceptions import ButlerUserError from ....queries.driver import ( DataCoordinateResultPage, DatasetRefResultPage, @@ -38,36 +35,16 @@ ResultPage, ResultSpec, ) -from ..._errors import serialize_butler_user_error from ...server_models import ( DataCoordinateResultModel, DatasetRefResultModel, DimensionRecordsResultModel, GeneralResultModel, - QueryErrorResultModel, QueryExecuteResultData, ) -def serialize_query_pages( - spec: ResultSpec, pages: Iterable[ResultPage] -) -> Iterator[str]: # numpydoc ignore=PR01 - """Serialize result pages to pages of result data in JSON format. The - output contains one page object per line, as newline-delimited JSON records - in the "JSON Lines" format (https://jsonlines.org/). - """ - try: - for page in pages: - yield _convert_query_page(spec, page).model_dump_json() - yield "\n" - except ButlerUserError as e: - # If a user-facing error occurs, serialize it and send it to the - # client. - yield QueryErrorResultModel(error=serialize_butler_user_error(e)).model_dump_json() - yield "\n" - - -def _convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResultData: +def convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResultData: """Convert pages of result data from the query system to a serializable format. @@ -75,8 +52,13 @@ def _convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResul ---------- spec : `ResultSpec` Definition of the output format for the results. - pages : `ResultPage` + page : `ResultPage` Raw page of data from the query driver. + + Returns + ------- + model : `QueryExecuteResultData` + Serializable pydantic model version of the page. """ match spec.result_type: case "dimension_record": diff --git a/python/lsst/daf/butler/remote_butler/server_models.py b/python/lsst/daf/butler/remote_butler/server_models.py index 876bfe9839..5ae692506f 100644 --- a/python/lsst/daf/butler/remote_butler/server_models.py +++ b/python/lsst/daf/butler/remote_butler/server_models.py @@ -299,12 +299,24 @@ class QueryErrorResultModel(pydantic.BaseModel): error: ErrorResponseModel +class QueryKeepAliveModel(pydantic.BaseModel): + """Result model for /query/execute used to keep connection alive. + + Some queries require a significant start-up time before they can start + returning results, or a long processing time for each chunk of rows. This + message signals that the server is still fetching the data. + """ + + type: Literal["keep-alive"] = "keep-alive" + + QueryExecuteResultData: TypeAlias = Annotated[ DataCoordinateResultModel | DimensionRecordsResultModel | DatasetRefResultModel | GeneralResultModel - | QueryErrorResultModel, + | QueryErrorResultModel + | QueryKeepAliveModel, pydantic.Field(discriminator="type"), ] From 670ba4d1b7d3df751c3d2d857e17837239818538 Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Thu, 29 Aug 2024 15:28:54 -0700 Subject: [PATCH 3/5] Remove GZIP middleware from Butler server Remove the GzipMiddleware for two reasons: 1. It was preventing "keep-alive" messages from working in streamed responses, because it was batching them until it had a full gzip chunk. 2. It was slowing down queries by about 10% at the RSP -- it turns out that compressing/decompressing is expensive, and there is lots of bandwidth available within Google. --- python/lsst/daf/butler/remote_butler/server/_server.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/lsst/daf/butler/remote_butler/server/_server.py b/python/lsst/daf/butler/remote_butler/server/_server.py index 9c72f520d5..edf9dea405 100644 --- a/python/lsst/daf/butler/remote_butler/server/_server.py +++ b/python/lsst/daf/butler/remote_butler/server/_server.py @@ -33,7 +33,6 @@ import safir.dependencies.logger from fastapi import FastAPI, Request, Response -from fastapi.middleware.gzip import GZipMiddleware from fastapi.staticfiles import StaticFiles from safir.logging import configure_logging, configure_uvicorn_logging @@ -54,7 +53,6 @@ def create_app() -> FastAPI: config = load_config() app = FastAPI() - app.add_middleware(GZipMiddleware, minimum_size=1000) # A single instance of the server can serve data from multiple Butler # repositories. This 'repository' path placeholder is consumed by From 3d29584820e60b0088e08fa743a6ed528778dbcb Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Thu, 29 Aug 2024 16:25:24 -0700 Subject: [PATCH 4/5] Simplify query resource management in server --- .../server/handlers/_external_query.py | 134 +++++++----------- 1 file changed, 54 insertions(+), 80 deletions(-) diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py index 3b5f627b61..8132702b44 100644 --- a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py +++ b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py @@ -31,7 +31,7 @@ import asyncio from collections.abc import AsyncIterator, Iterator -from contextlib import ExitStack, contextmanager +from contextlib import contextmanager from typing import NamedTuple from fastapi import APIRouter, Depends @@ -63,61 +63,26 @@ @query_router.post("/v1/query/execute", summary="Query the Butler database and return full results") -def query_execute( +async def query_execute( request: QueryExecuteRequestModel, factory: Factory = Depends(factory_dependency) ) -> StreamingResponse: - # Managing the lifetime of the query context object is a little tricky. We - # need to enter the context here, so that we can immediately deal with any - # exceptions raised by query set-up. We eventually transfer control to an - # iterator consumed by FastAPI's StreamingResponse handler, which will - # start iterating after this function returns. So we use this ExitStack - # instance to hand over the context manager to the iterator. - with ExitStack() as exit_stack: - ctx = exit_stack.enter_context(_get_query_context(factory, request.query)) - spec = request.result_spec.to_result_spec(ctx.driver.universe) - - # We write the response incrementally, one page at a time, as - # newline-separated chunks of JSON. This allows clients to start - # reading results earlier and prevents the server from exhausting - # all its memory buffering rows from large queries. - output_generator = _stream_query_pages( - # Transfer control of the context manager to - # _stream_query_pages. - exit_stack.pop_all(), - ctx, - spec, - ) - return StreamingResponse( - output_generator, - media_type="application/jsonlines", - headers={ - # Instruct the Kubernetes ingress to not buffer the response, - # so that keep-alives reach the client promptly. - "X-Accel-Buffering": "no" - }, - ) - - # Mypy thinks that ExitStack might swallow an exception. - assert False, "This line is unreachable." - - -# Instead of declaring this as a sync generator with 'def', it's async to -# give us more control over the lifetime of exit_stack. StreamingResponse -# ensures that this async generator is cancelled if the client -# disconnects or another error occurs, ensuring that clean-up logic runs. -# -# If it was sync, it would get wrapped in an async function internal to -# FastAPI that does not guarantee that the generator is fully iterated or -# closed. -# (There is an example in the FastAPI docs showing StreamingResponse with a -# sync generator with a context manager, but after reading the FastAPI -# source code I believe that for sync generators it will leak the context -# manager if the client disconnects, and that it would be -# difficult/impossible for them to fix this in the general case within -# FastAPI.) -async def _stream_query_pages( - exit_stack: ExitStack, ctx: _QueryContext, spec: ResultSpec -) -> AsyncIterator[str]: + # We write the response incrementally, one page at a time, as + # newline-separated chunks of JSON. This allows clients to start + # reading results earlier and prevents the server from exhausting + # all its memory buffering rows from large queries. + output_generator = _stream_query_pages(request, factory) + return StreamingResponse( + output_generator, + media_type="application/jsonlines", + headers={ + # Instruct the Kubernetes ingress to not buffer the response, + # so that keep-alives reach the client promptly. + "X-Accel-Buffering": "no" + }, + ) + + +async def _stream_query_pages(request: QueryExecuteRequestModel, factory: Factory) -> AsyncIterator[str]: """Stream the query output with one page object per line, as newline-delimited JSON records in the "JSON Lines" format (https://jsonlines.org/). @@ -125,25 +90,46 @@ async def _stream_query_pages( When it takes longer than 15 seconds to get a response from the DB, sends a keep-alive message to prevent clients from timing out. """ - # Ensure that the database connection is cleaned up by taking control of - # exit_stack. - async with contextmanager_in_threadpool(exit_stack): - # `None` signals that there is no more data to send. - queue = asyncio.Queue[QueryExecuteResultData | None](1) - async with asyncio.TaskGroup() as tg: - tg.create_task(_enqueue_query_pages(ctx, spec, queue)) - async for message in _dequeue_query_pages_with_keepalive(queue): - yield message.model_dump_json() + "\n" - - -async def _enqueue_query_pages(ctx: _QueryContext, spec: ResultSpec, queue: asyncio.Queue) -> None: - async for page in iterate_in_threadpool(_retrieve_query_pages(ctx, spec)): - await queue.put(page) + # `None` signals that there is no more data to send. + queue = asyncio.Queue[QueryExecuteResultData | None](1) + async with asyncio.TaskGroup() as tg: + # Run a background task to read from the DB and insert the result pages + # into a queue. + tg.create_task(_enqueue_query_pages(queue, request, factory)) + # Read the result pages from the queue and send them to the client, + # inserting a keep-alive message every 15 seconds if we are waiting a + # long time for the database. + async for message in _dequeue_query_pages_with_keepalive(queue): + yield message.model_dump_json() + "\n" + + +async def _enqueue_query_pages( + queue: asyncio.Queue[QueryExecuteResultData | None], request: QueryExecuteRequestModel, factory: Factory +) -> None: + """Set up a QueryDriver to run the query, and copy the results into a + queue. Send `None` to the queue when there is no more data to read. + """ + try: + async with contextmanager_in_threadpool(_get_query_context(factory, request.query)) as ctx: + spec = request.result_spec.to_result_spec(ctx.driver.universe) + async for page in iterate_in_threadpool(_retrieve_query_pages(ctx, spec)): + await queue.put(page) + except ButlerUserError as e: + # If a user-facing error occurs, serialize it and send it to the + # client. + await queue.put(QueryErrorResultModel(error=serialize_butler_user_error(e))) # Signal that there is no more data to read. await queue.put(None) +def _retrieve_query_pages(ctx: _QueryContext, spec: ResultSpec) -> Iterator[QueryExecuteResultData]: + """Execute the database query and and return pages of results.""" + pages = ctx.driver.execute(spec, ctx.tree) + for page in pages: + yield convert_query_page(spec, page) + + async def _dequeue_query_pages_with_keepalive( queue: asyncio.Queue[QueryExecuteResultData | None], ) -> AsyncIterator[QueryExecuteResultData]: @@ -162,18 +148,6 @@ async def _dequeue_query_pages_with_keepalive( yield QueryKeepAliveModel() -def _retrieve_query_pages(ctx: _QueryContext, spec: ResultSpec) -> Iterator[QueryExecuteResultData]: - """Execute the database query and and return pages of results.""" - try: - pages = ctx.driver.execute(spec, ctx.tree) - for page in pages: - yield convert_query_page(spec, page) - except ButlerUserError as e: - # If a user-facing error occurs, serialize it and send it to the - # client. - yield QueryErrorResultModel(error=serialize_butler_user_error(e)) - - @query_router.post( "/v1/query/count", summary="Query the Butler database and return a count of rows that would be returned.", From 7efb63b8474acefde567dd1da52545ee8c5e9efa Mon Sep 17 00:00:00 2001 From: "David H. Irving" Date: Fri, 30 Aug 2024 13:30:28 -0700 Subject: [PATCH 5/5] Add unit test for keepalive behavior --- .../daf/butler/remote_butler/_query_driver.py | 11 ++++- .../server/handlers/_external_query.py | 5 ++- tests/test_server.py | 42 ++++++++++++++++++- 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/python/lsst/daf/butler/remote_butler/_query_driver.py b/python/lsst/daf/butler/remote_butler/_query_driver.py index d290e62ffe..8bc9097cbf 100644 --- a/python/lsst/daf/butler/remote_butler/_query_driver.py +++ b/python/lsst/daf/butler/remote_butler/_query_driver.py @@ -148,7 +148,9 @@ def execute(self, result_spec: ResultSpec, tree: QueryTree) -> Iterator[ResultPa # response. for line in response.iter_lines(): result_chunk: QueryExecuteResultData = _QueryResultTypeAdapter.validate_json(line) - if result_chunk.type != "keep-alive": + if result_chunk.type == "keep-alive": + _received_keep_alive() + else: yield _convert_query_result_page(result_spec, result_chunk, universe) if self._closed: raise RuntimeError( @@ -280,3 +282,10 @@ def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) for row in model.rows ] return GeneralResultPage(spec=spec, rows=rows) + + +def _received_keep_alive() -> None: + """Do nothing. Gives a place for unit tests to hook in for testing + keepalive behavior. + """ + pass diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py index 8132702b44..56767f29a2 100644 --- a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py +++ b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py @@ -61,6 +61,9 @@ query_router = APIRouter() +# Alias this function so we can mock it during unit tests. +_timeout = asyncio.timeout + @query_router.post("/v1/query/execute", summary="Query the Butler database and return full results") async def query_execute( @@ -139,7 +142,7 @@ async def _dequeue_query_pages_with_keepalive( """ while True: try: - async with asyncio.timeout(15): + async with _timeout(15): message = await queue.get() if message is None: return diff --git a/tests/test_server.py b/tests/test_server.py index c6390a203c..33cae91c8c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -34,6 +34,8 @@ try: # Failing to import any of these should disable the tests. + import lsst.daf.butler.remote_butler._query_driver + import lsst.daf.butler.remote_butler.server.handlers._external_query import safir.dependencies.logger from fastapi.testclient import TestClient from lsst.daf.butler.remote_butler import RemoteButler @@ -47,7 +49,7 @@ create_test_server = None reason_text = str(e) -from unittest.mock import NonCallableMock, patch +from unittest.mock import DEFAULT, NonCallableMock, patch from lsst.daf.butler import ( Butler, @@ -402,6 +404,28 @@ async def get_logger(): self.assertEqual(kwargs["clientRequestId"], "request-id") self.assertEqual(kwargs["user"], "user-name") + def test_query_keepalive(self): + """Test that long-running queries stream keep-alive messages to stop + the HTTP connection from closing before they are able to return + results. + """ + # Normally it takes 15 seconds for a timeout -- mock it to trigger + # immediately instead. + with patch.object( + lsst.daf.butler.remote_butler.server.handlers._external_query, "_timeout" + ) as mock_timeout: + # Hook into QueryDriver to track the number of keep-alives we have + # seen. + with patch.object( + lsst.daf.butler.remote_butler._query_driver, "_received_keep_alive" + ) as mock_keep_alive: + mock_timeout.side_effect = _timeout_twice() + with self.butler._query() as query: + datasets = list(query.datasets("bias", "imported_g")) + self.assertEqual(len(datasets), 3) + self.assertGreaterEqual(mock_timeout.call_count, 3) + self.assertGreaterEqual(mock_keep_alive.call_count, 2) + def _create_corrupted_dataset(repo: MetricTestRepo) -> DatasetRef: run = "corrupted-run" @@ -418,5 +442,21 @@ def _create_simple_dataset(butler: Butler) -> DatasetRef: return ref +def _timeout_twice(): + """Return a mock side-effect function that raises a timeout error the first + two times it is called. + """ + count = 0 + + def timeout(*args): + nonlocal count + count += 1 + if count <= 2: + raise TimeoutError() + return DEFAULT + + return timeout + + if __name__ == "__main__": unittest.main()