diff --git a/flowmachine/flowmachine/core/context.py b/flowmachine/flowmachine/core/context.py index c9bcf485c6..86dca9c281 100644 --- a/flowmachine/flowmachine/core/context.py +++ b/flowmachine/flowmachine/core/context.py @@ -9,6 +9,8 @@ from contextvars import ContextVar, copy_context from concurrent.futures import Executor, Future + +import uuid from contextlib import contextmanager from typing import Callable, NamedTuple @@ -34,9 +36,13 @@ action_request except NameError: action_request = ContextVar("action_request") +try: + interpreter_id +except NameError: + interpreter_id = ContextVar("interpreter_id", default=str(uuid.uuid4())) -_jupyter_context = ( - dict() +_jupyter_context = dict( + interpreter_id=interpreter_id.get() ) # Required as a workaround for https://github.com/ipython/ipython/issues/11565 _is_notebook = False @@ -129,6 +135,17 @@ def get_executor() -> Executor: raise NotConnectedError +def get_interpreter_id() -> str: + global _jupyter_context + try: + if _is_notebook: + return interpreter_id.get(_jupyter_context["interpreter_id"]) + else: + return interpreter_id.get() + except (LookupError, KeyError): + raise RuntimeError("No interpreter id.") + + def submit_to_executor(func: Callable, *args, **kwargs) -> Future: """ Submit a callable to the current context's executor pool and @@ -178,6 +195,9 @@ def bind_context( db.set(connection) executor.set(executor_pool) redis_connection.set(redis_conn) + from flowmachine.core.init import _register_exit_handlers + + _register_exit_handlers(redis_conn) @contextmanager @@ -207,6 +227,9 @@ def context(connection: Connection, executor_pool: Executor, redis_conn: StrictR db_token = db.set(connection) redis_token = redis_connection.set(redis_conn) executor_token = executor.set(executor_pool) + from flowmachine.core.init import _register_exit_handlers + + _register_exit_handlers(redis_conn) try: yield finally: diff --git a/flowmachine/flowmachine/core/init.py b/flowmachine/flowmachine/core/init.py index 
a9dc3a1ca8..958faa6a43 100644 --- a/flowmachine/flowmachine/core/init.py +++ b/flowmachine/flowmachine/core/init.py @@ -10,6 +10,8 @@ From a developer perspective, this is where one-time operations should live - for example configuring loggers. """ +import atexit + import warnings from contextlib import contextmanager @@ -22,12 +24,15 @@ import flowmachine from flowmachine.core import Connection -from flowmachine.core.context import bind_context, context, get_db +from flowmachine.core.context import bind_context, context, get_db, get_redis from flowmachine.core.errors import NotConnectedError from flowmachine.core.logging import set_log_level from get_secret_or_env_var import environ, getenv +from flowmachine.core.query_manager import release_managed + logger = structlog.get_logger("flowmachine.debug", submodule=__name__) +_exit_handlers_registered = False @contextmanager @@ -321,4 +326,17 @@ def _do_connect( print( f"Flowdb running on: {flowdb_host}:{flowdb_port}/flowdb (connecting user: {flowdb_user})" ) + _register_exit_handlers(redis_connection) return conn, thread_pool, redis_connection + + +def _register_exit_handlers(redis_connection): + import signal + + global _exit_handlers_registered + if _exit_handlers_registered: + return + atexit.register(release_managed, redis_connection) + signal.signal(signal.SIGTERM, lambda sig, frame: release_managed(redis_connection)) + signal.signal(signal.SIGINT, lambda sig, frame: release_managed(redis_connection)) + _exit_handlers_registered = True diff --git a/flowmachine/flowmachine/core/logging.py b/flowmachine/flowmachine/core/logging.py index aba0beaf7c..e4e589fffa 100644 --- a/flowmachine/flowmachine/core/logging.py +++ b/flowmachine/flowmachine/core/logging.py @@ -6,7 +6,7 @@ import rapidjson import structlog import sys -from flowmachine.core.context import get_action_request +from flowmachine.core.context import get_action_request, get_interpreter_id __all__ = ["init_logging", "set_log_level"] @@ -25,6 +25,10 @@ def action_request_processor(_, __, event_dict): ) except LookupError: pass + event_dict = dict( + 
**event_dict, + interpreter_id=get_interpreter_id(), + ) return event_dict diff --git a/flowmachine/flowmachine/core/query.py b/flowmachine/flowmachine/core/query.py index a1959014b0..27e0b371b2 100644 --- a/flowmachine/flowmachine/core/query.py +++ b/flowmachine/flowmachine/core/query.py @@ -30,6 +30,7 @@ submit_to_executor, ) from flowmachine.core.errors.flowmachine_errors import QueryResetFailedException +from flowmachine.core.query_manager import get_manager from flowmachine.core.query_state import QueryStateMachine from abc import ABCMeta, abstractmethod @@ -579,12 +580,27 @@ def to_sql( unstored_dependencies_graph(self) ) # Need to ensure we're behind our deps in the queue + ddl_ops_func = self._make_sql + + logger.debug("Attempting to queue query", query_id=self.query_id) current_state, changed_to_queue = QueryStateMachine( get_redis(), self.query_id, get_db().conn_id ).enqueue() - logger.debug( - f"Attempted to enqueue query '{self.query_id}', query state is now {current_state} and change happened {'here and now' if changed_to_queue else 'elsewhere'}." 
- ) + if changed_to_queue: + logger.debug("Queued", query_id=self.query_id) + else: + logger.debug( + "Not queued", query_id=self.query_id, query_state=current_state + ) + try: + logger.debug( + "Managed elsewhere.", + manager=get_manager(self.query_id, get_db().conn_id, get_redis()), + ) + except AttributeError: + pass # Not being managed + + # name, redis, query, connection, ddl_ops_func, write_func, schema = None, sleep_duration = 1 store_future = submit_to_executor( write_query_to_cache, name=name, diff --git a/flowmachine/flowmachine/core/query_manager.py b/flowmachine/flowmachine/core/query_manager.py new file mode 100644 index 0000000000..c313806ad5 --- /dev/null +++ b/flowmachine/flowmachine/core/query_manager.py @@ -0,0 +1,66 @@ +from typing import List + +import datetime +import structlog +from contextlib import contextmanager +from redis import StrictRedis + +from flowmachine.core.context import get_redis, get_interpreter_id + +logger = structlog.get_logger("flowmachine.debug", submodule=__name__) + + +@contextmanager +def managing(query_id: str, db_id: str, redis_connection: StrictRedis): + set_managing(query_id, db_id, redis_connection) + yield + unset_managing(query_id, db_id, redis_connection) + + +def set_managing(query_id: str, db_id: str, redis_connection: StrictRedis) -> None: + logger.debug("Setting manager.", query_id=query_id, db_id=db_id) + redis_connection.hset("manager", f"{db_id}-{query_id}", get_interpreter_id()) + redis_connection.hset( + f"managing:{get_interpreter_id()}", + f"{db_id}-{query_id}", + datetime.datetime.now().isoformat(), + ) + logger.debug("Set manager.", query_id=query_id, db_id=db_id) + + +def unset_managing(query_id: str, db_id: str, redis_connection: StrictRedis) -> None: + logger.debug("Releasing manager.", query_id=query_id, db_id=db_id) + redis_connection.hdel("manager", f"{db_id}-{query_id}") + redis_connection.hdel(f"managing:{get_interpreter_id()}", f"{db_id}-{query_id}") + logger.debug("Released manager.", 
query_id=query_id, db_id=db_id) + + +def get_managed(redis_connection: StrictRedis) -> List[str]: + return [ + k.decode() + for k in redis_connection.hgetall(f"managing:{get_interpreter_id()}").keys() + ] + + +def get_manager( + query_id: str, connection_id: str, redis_connection: StrictRedis +) -> str: + return redis_connection.hget("manager", f"{connection_id}-{query_id}").decode() + + +def release_managed(redis_connection) -> None: + from flowmachine.core.query_state import QueryStateMachine + + logger.error("Releasing managed queries.") + for query_id in get_managed(redis_connection): + conn_id, query_id = query_id.split("-") + qsm = QueryStateMachine( + db_id=conn_id, + query_id=query_id, + redis_client=redis_connection, + ) + qsm.raise_error() + qsm.cancel() + qsm.reset() + qsm.finish_resetting() + unset_managing(query_id, conn_id, redis_connection) diff --git a/flowmachine/flowmachine/core/query_state.py b/flowmachine/flowmachine/core/query_state.py index d79d63bbee..e2a9044681 100644 --- a/flowmachine/flowmachine/core/query_state.py +++ b/flowmachine/flowmachine/core/query_state.py @@ -16,6 +16,7 @@ from redis import StrictRedis +from flowmachine.core.query_manager import set_managing, unset_managing from flowmachine.utils import _sleep logger = logging.getLogger("flowmachine").getChild(__name__) @@ -103,7 +104,9 @@ class QueryStateMachine: def __init__(self, redis_client: StrictRedis, query_id: str, db_id: str): self.query_id = query_id + self.db_id = db_id must_populate = redis_client.get(f"finist:{db_id}:{query_id}-state") is None + self.redis_client = redis_client self.state_machine = Finist( redis_client, f"{db_id}:{query_id}-state", QueryState.KNOWN ) @@ -270,7 +273,10 @@ triggered the query to be cancelled with this call """ - return self.trigger_event(QueryEvent.CANCEL) + state, changed = self.trigger_event(QueryEvent.CANCEL) + if changed: + unset_managing(self.query_id, self.db_id, self.redis_client) + return state, changed def 
enqueue(self): """ @@ -283,7 +289,10 @@ def enqueue(self): triggered the query to be queued with this call """ - return self.trigger_event(QueryEvent.QUEUE) + state, changed = self.trigger_event(QueryEvent.QUEUE) + if changed: + set_managing(self.query_id, self.db_id, self.redis_client) + return state, changed def raise_error(self): """ @@ -296,7 +305,10 @@ def raise_error(self): marked the query as erroring with this call """ - return self.trigger_event(QueryEvent.ERROR) + state, changed = self.trigger_event(QueryEvent.ERROR) + if changed: + unset_managing(self.query_id, self.db_id, self.redis_client) + return state, changed def execute(self): """ @@ -322,7 +334,10 @@ def finish(self): marked the query as finished with this call """ - return self.trigger_event(QueryEvent.FINISH) + state, changed = self.trigger_event(QueryEvent.FINISH) + if changed: + unset_managing(self.query_id, self.db_id, self.redis_client) + return state, changed def reset(self): """ diff --git a/flowmachine/flowmachine/core/server/server.py b/flowmachine/flowmachine/core/server/server.py index b21e6e90cc..46ae6dcef5 100644 --- a/flowmachine/flowmachine/core/server/server.py +++ b/flowmachine/flowmachine/core/server/server.py @@ -22,7 +22,14 @@ import flowmachine from flowmachine.core import Query, Connection from flowmachine.core.cache import watch_and_shrink_cache -from flowmachine.core.context import get_db, get_executor, action_request_context +from flowmachine.core.context import ( + get_db, + get_executor, + action_request_context, + get_redis, +) +from flowmachine.core.query_manager import get_managed, release_managed +from flowmachine.core.query_state import QueryStateMachine from flowmachine.utils import convert_dict_keys_to_strings from .exceptions import FlowmachineServerError from .zmq_helpers import ZMQReply @@ -266,10 +273,13 @@ def main(): logger.info("Enabling asyncio's debugging mode.") # Run receive loop which receives zmq messages and sends back replies - asyncio.run( - 
recv(config=config), - debug=config.debug_mode, - ) # note: asyncio.run() requires Python 3.7+ + try: + asyncio.run( + recv(config=config), + debug=config.debug_mode, + ) # note: asyncio.run() requires Python 3.7+ + finally: + release_managed(get_redis()) if __name__ == "__main__": diff --git a/flowmachine/tests/conftest.py b/flowmachine/tests/conftest.py index 103461aed5..8409537aae 100644 --- a/flowmachine/tests/conftest.py +++ b/flowmachine/tests/conftest.py @@ -297,6 +297,15 @@ def hset(self, key, current, next): except KeyError: self._store[key] = {current.encode(): next.encode()} + def hgetall(self, key): + return self._store.get(key, {}) + + def hdel(self, key, name): + try: + del self._store[key][name.encode()] + except KeyError: + return 0 + def set(self, key, value): self._store[key] = value.encode() diff --git a/flowmachine/tests/test_context.py b/flowmachine/tests/test_context.py index f72f0759b8..04b9805924 100644 --- a/flowmachine/tests/test_context.py +++ b/flowmachine/tests/test_context.py @@ -5,7 +5,7 @@ import importlib import pytest -from flowmachine.core.errors import NotConnectedError +from flowmachine.core.context import get_interpreter_id, get_executor @pytest.fixture(autouse=True) @@ -28,6 +28,13 @@ def flowmachine_connect(): # Override the autoused fixture from the parent pass +def test_consistent_interpreter_id(): + import flowmachine + + flowmachine.connect() + assert get_executor().submit(get_interpreter_id).result() == get_interpreter_id() + + def test_context_manager(): import flowmachine @@ -81,7 +88,7 @@ def test_notebook_detection_without_ipython_shell(monkeypatch): importlib.reload(flowmachine.core.context) assert not flowmachine.core.context._is_notebook - assert len(flowmachine.core.context._jupyter_context) == 0 + assert len(flowmachine.core.context._jupyter_context) == 1 def test_notebook_workaround(monkeypatch): diff --git a/flowmachine/tests/test_query_manager.py b/flowmachine/tests/test_query_manager.py new file mode 100644 index 
0000000000..016af16acb --- /dev/null +++ b/flowmachine/tests/test_query_manager.py @@ -0,0 +1,38 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +""" +Tests for the query manger. +""" +from flowmachine.core.context import get_interpreter_id +from flowmachine.core.query_manager import set_managing, unset_managing, release_managed + + +def test_set_managing(dummy_redis): + set_managing("DUMMY_ID", "DUMMY_DB_ID", dummy_redis) + assert ( + dummy_redis.get("manager")[b"DUMMY_DB_ID-DUMMY_ID"] + == get_interpreter_id().encode() + ) + assert b"DUMMY_DB_ID-DUMMY_ID" in dummy_redis.get( + f"managing:{get_interpreter_id()}" + ) + + +def test_unset_managing(dummy_redis): + test_set_managing(dummy_redis) + unset_managing("DUMMY_ID", "DUMMY_DB_ID", dummy_redis) + assert b"DUMMY_DB_ID-DUMMY_ID" not in dummy_redis.get("manager") + assert b"DUMMY_DB_ID-DUMMY_ID" not in dummy_redis.get( + f"managing:{get_interpreter_id()}" + ) + + +def test_release_managed(dummy_redis): + test_set_managing(dummy_redis) + release_managed(dummy_redis) + assert b"DUMMY_DB_ID-DUMMY_ID" not in dummy_redis.get("manager") + assert b"DUMMY_DB_ID-DUMMY_ID" not in dummy_redis.get( + f"managing:{get_interpreter_id()}" + ) diff --git a/flowmachine/tests/test_query_state.py b/flowmachine/tests/test_query_state.py index 539f5c287b..008b6c759f 100644 --- a/flowmachine/tests/test_query_state.py +++ b/flowmachine/tests/test_query_state.py @@ -13,12 +13,13 @@ import flowmachine from flowmachine.core import Query -from flowmachine.core.context import get_redis, get_db +from flowmachine.core.context import get_redis, get_db, get_interpreter_id from flowmachine.core.errors.flowmachine_errors import ( QueryCancelledException, QueryErroredException, QueryResetFailedException, ) +from flowmachine.core.query_manager import set_managing from 
flowmachine.core.query_state import QueryStateMachine, QueryState, QueryEvent import flowmachine.utils @@ -129,6 +130,7 @@ def test_query_cancellation(start_state, succeeds, dummy_redis): """Test the cancel method works as expected.""" state_machine = QueryStateMachine(dummy_redis, "DUMMY_QUERY_ID", get_db().conn_id) dummy_redis.set(state_machine.state_machine._name, start_state) + set_managing("DUMMY_QUERY_ID", "DUMMY_DB_ID", dummy_redis) state_machine.cancel() assert succeeds == state_machine.is_cancelled