From 0f2afe970d59dfe1db2a63401c15171dc9023306 Mon Sep 17 00:00:00 2001
From: "David H. Irving"
Date: Fri, 30 Aug 2024 15:36:05 -0700
Subject: [PATCH] Improve cancellation of running database queries

Add a Database.cancel_running_query method, with driver-specific
implementations for PostgreSQL (psycopg2 cancel) and SQLite (pysqlite
interrupt), and call it from the Butler server's query-streaming handler
so that a blocking database call is aborted when the asyncio task
driving it is cancelled.
---
 .../butler/registry/databases/postgresql.py   |  4 ++
 .../daf/butler/registry/databases/sqlite.py   |  4 ++
 .../butler/registry/interfaces/_database.py   | 24 ++++++++++++
 .../server/handlers/_external_query.py        | 35 ++++++++++++++----
 4 files changed, 59 insertions(+), 8 deletions(-)

diff --git a/python/lsst/daf/butler/registry/databases/postgresql.py b/python/lsst/daf/butler/registry/databases/postgresql.py
index 51df4e9f3c..0277a9dd17 100644
--- a/python/lsst/daf/butler/registry/databases/postgresql.py
+++ b/python/lsst/daf/butler/registry/databases/postgresql.py
@@ -396,6 +396,10 @@ def apply_any_aggregate(self, column: sqlalchemy.ColumnElement[Any]) -> sqlalche
         # would become a String column in the output.
         return sqlalchemy.cast(sqlalchemy.func.any_value(column), column.type)
 
+    def _cancel_running_query(self, connection: sqlalchemy.engine.interfaces.DBAPIConnection) -> None:
+        # This is a psycopg2-specific extension method.
+        connection.cancel()  # type: ignore
+
 
 class _RangeTimespanType(sqlalchemy.TypeDecorator):
     """A single-column `Timespan` representation usable only with
diff --git a/python/lsst/daf/butler/registry/databases/sqlite.py b/python/lsst/daf/butler/registry/databases/sqlite.py
index 6db681e0e7..edd3ef4f15 100644
--- a/python/lsst/daf/butler/registry/databases/sqlite.py
+++ b/python/lsst/daf/butler/registry/databases/sqlite.py
@@ -412,6 +412,10 @@ def apply_any_aggregate(self, column: sqlalchemy.ColumnElement[Any]) -> sqlalche
         # arbitrary value picked if there is more than one.
         return column
 
+    def _cancel_running_query(self, connection: sqlalchemy.engine.interfaces.DBAPIConnection) -> None:
+        # This is a pysqlite-specific extension method.
+        connection.interrupt()  # type: ignore
+
     filename: str | None
     """Name of the file this database is connected to (`str` or `None`).
 
diff --git a/python/lsst/daf/butler/registry/interfaces/_database.py b/python/lsst/daf/butler/registry/interfaces/_database.py
index 0e69a14cbb..9436873644 100644
--- a/python/lsst/daf/butler/registry/interfaces/_database.py
+++ b/python/lsst/daf/butler/registry/interfaces/_database.py
@@ -1995,6 +1995,30 @@ def apply_any_aggregate(self, column: sqlalchemy.ColumnElement[Any]) -> sqlalche
         """
         raise NotImplementedError()
 
+    def cancel_running_query(self) -> None:
+        """Attempt to cancel an in-progress query that is using this database
+        connection.
+
+        Notes
+        -----
+        Does nothing if no query is currently active. This method may be
+        called from a different thread than the one running the query.
+        Cancellation is best-effort: the underlying database drivers do not
+        guarantee that an in-progress query can always be interrupted.
+        """
+        connection = self._session_connection
+        if connection is not None:
+            db = connection.connection.dbapi_connection
+            if db is not None:
+                self._cancel_running_query(db)
+
+    @abstractmethod
+    def _cancel_running_query(self, connection: sqlalchemy.engine.interfaces.DBAPIConnection) -> None:
+        """Driver-specific implementation of ``cancel_running_query``,
+        operating on the raw DBAPI connection.
+        """
+        raise NotImplementedError()
+
     origin: int
     """An integer ID that should be used as the default for any datasets,
     quanta, or other entities that use a (autoincrement, origin) compound
diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
index 56767f29a2..265a501b21 100644
--- a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
+++ b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
@@ -53,7 +53,8 @@
 )
 
 from ...._exceptions import ButlerUserError
-from ....queries.driver import QueryDriver, QueryTree, ResultSpec
+from ....direct_query_driver import DirectQueryDriver
+from ....queries.driver import QueryTree, ResultSpec
 from ..._errors import serialize_butler_user_error
 from .._dependencies import factory_dependency
 from .._factory import Factory
@@ -98,25 +99,43 @@ async def _stream_query_pages(request: QueryExecuteRequestModel, factory: Factor
     async with asyncio.TaskGroup() as tg:
         # Run a background task to read from the DB and insert the result pages
         # into a queue.
-        tg.create_task(_enqueue_query_pages(queue, request, factory))
+        tg.create_task(_execute_query(queue, request, factory))
         # Read the result pages from the queue and send them to the client,
         # inserting a keep-alive message every 15 seconds if we are waiting a
         # long time for the database.
         async for message in _dequeue_query_pages_with_keepalive(queue):
             yield message.model_dump_json() + "\n"
 
 
-async def _enqueue_query_pages(
+async def _execute_query(
     queue: asyncio.Queue[QueryExecuteResultData | None], request: QueryExecuteRequestModel, factory: Factory
 ) -> None:
     """Set up a QueryDriver to run the query, and copy the results into a
     queue. Send `None` to the queue when there is no more data to read.
     """
+    async with contextmanager_in_threadpool(_get_query_context(factory, request.query)) as ctx:
+        # Do the actual work in another task so we can explicitly handle
+        # cancellation. The database calls in _retrieve_query_pages can
+        # block for a very long time waiting for a response from the DB.
+        # Since that is a synchronous call in another thread, the `await`
+        # can't be cancelled until the sync call finishes. So we have to
+        # forcibly cancel the database query to get the sync call to abort.
+        async with asyncio.TaskGroup() as tg:
+            task = tg.create_task(_enqueue_query_pages(queue, request, ctx))
+            try:
+                await asyncio.wait_for(task, None)
+            except asyncio.CancelledError:
+                ctx.driver.db.cancel_running_query()
+                raise
+
+
+async def _enqueue_query_pages(
+    queue: asyncio.Queue[QueryExecuteResultData | None], request: QueryExecuteRequestModel, ctx: _QueryContext
+) -> None:
     try:
-        async with contextmanager_in_threadpool(_get_query_context(factory, request.query)) as ctx:
-            spec = request.result_spec.to_result_spec(ctx.driver.universe)
-            async for page in iterate_in_threadpool(_retrieve_query_pages(ctx, spec)):
-                await queue.put(page)
+        spec = request.result_spec.to_result_spec(ctx.driver.universe)
+        async for page in iterate_in_threadpool(_retrieve_query_pages(ctx, spec)):
+            await queue.put(page)
     except ButlerUserError as e:
         # If a user-facing error occurs, serialize it and send it to the
         # client.
@@ -219,5 +238,5 @@ def _get_query_context(factory: Factory, query: QueryInputs) -> Iterator[_QueryC
 
 
 class _QueryContext(NamedTuple):
-    driver: QueryDriver
+    driver: DirectQueryDriver
     tree: QueryTree
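
Illustration (not part of the patch): the cancellation pattern that _execute_query relies on, reduced to a minimal self-contained sketch. Plain sqlite3 and asyncio.to_thread stand in for the Butler Database class and the threadpool helpers used in the handler, and the slow_query / run_query / main names are invented for this example; sqlite3's Connection.interrupt() plays the role of Database.cancel_running_query(). One difference: asyncio.to_thread delivers the CancelledError promptly, whereas the patch's comment notes that the handler's threadpool await cannot be cancelled until the synchronous call returns; in both cases, though, the worker thread only stops once the query is cancelled on the database side.

import asyncio
import sqlite3


def slow_query(db: sqlite3.Connection) -> None:
    # A query that blocks its worker thread indefinitely until interrupted.
    db.execute(
        "WITH RECURSIVE t(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM t) "
        "SELECT count(*) FROM t"
    )


async def run_query(db: sqlite3.Connection) -> None:
    try:
        # The blocking DBAPI call runs in a worker thread. Cancelling this
        # task raises CancelledError here, but on its own that does not stop
        # the thread: the query would keep running (and keep the connection
        # busy) until it finished by itself.
        await asyncio.to_thread(slow_query, db)
    except asyncio.CancelledError:
        # Interrupt the in-progress query so the worker thread exits promptly,
        # mirroring the ctx.driver.db.cancel_running_query() call in the patch.
        db.interrupt()
        raise


async def main() -> None:
    db = sqlite3.connect(":memory:", check_same_thread=False)
    task = asyncio.create_task(run_query(db))
    await asyncio.sleep(0.2)  # Give the query time to start executing.
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        print("query cancelled")
    # Without the interrupt() above, asyncio.run() would hang at shutdown,
    # waiting for the worker thread to finish the never-ending query.


asyncio.run(main())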