Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions flowmachine/flowmachine/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from contextvars import ContextVar, copy_context
from concurrent.futures import Executor, Future

import uuid
from contextlib import contextmanager
from typing import Callable, NamedTuple

Expand All @@ -34,9 +36,13 @@
action_request
except NameError:
action_request = ContextVar("action_request")
try:
interpreter_id
except NameError:
interpreter_id = ContextVar("interpreter_id", default=str(uuid.uuid4()))

_jupyter_context = (
dict()
_jupyter_context = dict(
interpreter_id=interpreter_id.get()
) # Required as a workaround for https://github.com/ipython/ipython/issues/11565

_is_notebook = False
Expand Down Expand Up @@ -129,6 +135,17 @@
raise NotConnectedError


def get_interpreter_id() -> str:
    """Return the unique id of this interpreter process.

    When running inside a notebook, the id stored in the module-level
    ``_jupyter_context`` dict is used as the fallback for the
    ``interpreter_id`` context variable (workaround for
    https://github.com/ipython/ipython/issues/11565, where contextvars do
    not propagate between cells); otherwise the context variable's own
    value/default is returned.

    Returns
    -------
    str
        UUID string identifying this interpreter.

    Raises
    ------
    RuntimeError
        If no interpreter id is set in either the context variable or the
        jupyter workaround dict.
    """
    # NOTE: the original declared ``global _jupyter_context`` but only reads
    # it, which requires no global statement; removed as a no-op.
    try:
        if _is_notebook:
            return interpreter_id.get(_jupyter_context["interpreter_id"])
        else:
            return interpreter_id.get()
    except (LookupError, KeyError) as exc:
        # Chain the original lookup failure for easier debugging.
        raise RuntimeError("No interpreter id.") from exc

Check warning on line 146 in flowmachine/flowmachine/core/context.py

View check run for this annotation

Codecov / codecov/patch

flowmachine/flowmachine/core/context.py#L145-L146

Added lines #L145 - L146 were not covered by tests


def submit_to_executor(func: Callable, *args, **kwargs) -> Future:
"""
Submit a callable to the current context's executor pool and
Expand Down Expand Up @@ -178,6 +195,9 @@
db.set(connection)
executor.set(executor_pool)
redis_connection.set(redis_conn)
from flowmachine.core.init import _register_exit_handlers

_register_exit_handlers(redis_conn)


@contextmanager
Expand Down Expand Up @@ -207,6 +227,9 @@
db_token = db.set(connection)
redis_token = redis_connection.set(redis_conn)
executor_token = executor.set(executor_pool)
from flowmachine.core.init import _register_exit_handlers

_register_exit_handlers(redis_conn)
try:
yield
finally:
Expand Down
16 changes: 15 additions & 1 deletion flowmachine/flowmachine/core/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
From a developer perspective, this is where one-time operations
should live - for example configuring loggers.
"""
import atexit

import warnings
from contextlib import contextmanager

Expand All @@ -22,12 +24,15 @@

import flowmachine
from flowmachine.core import Connection
from flowmachine.core.context import bind_context, context, get_db
from flowmachine.core.context import bind_context, context, get_db, get_redis
from flowmachine.core.errors import NotConnectedError
from flowmachine.core.logging import set_log_level
from get_secret_or_env_var import environ, getenv

from flowmachine.core.query_manager import release_managed

logger = structlog.get_logger("flowmachine.debug", submodule=__name__)
_exit_handlers_registered = False


@contextmanager
Expand Down Expand Up @@ -321,4 +326,13 @@ def _do_connect(
print(
f"Flowdb running on: {flowdb_host}:{flowdb_port}/flowdb (connecting user: {flowdb_user})"
)
_register_exit_handlers(redis_connection)
return conn, thread_pool, redis_connection


def _register_exit_handlers(redis_connection):
    """Register cleanup handlers that release this interpreter's managed queries.

    Installs an ``atexit`` hook and SIGTERM/SIGINT handlers which call
    :func:`release_managed` so redis manager records are not orphaned when
    the interpreter dies. Idempotent: the module-level
    ``_exit_handlers_registered`` flag ensures handlers are installed only
    once per process, even though this is called from several connect paths.

    Parameters
    ----------
    redis_connection : StrictRedis
        Redis connection used to look up and release managed queries.
    """
    global _exit_handlers_registered
    if _exit_handlers_registered:
        return
    import signal

    # Bug fix: release_managed takes only the redis connection; the previous
    # ``atexit.register(release_managed, None, None, redis_connection)``
    # would raise TypeError when the hook actually ran at exit.
    atexit.register(release_managed, redis_connection)
    signal.signal(signal.SIGTERM, lambda sig, frame: release_managed(redis_connection))
    signal.signal(signal.SIGINT, lambda sig, frame: release_managed(redis_connection))
    _exit_handlers_registered = True
6 changes: 5 additions & 1 deletion flowmachine/flowmachine/core/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import rapidjson
import structlog
import sys
from flowmachine.core.context import get_action_request
from flowmachine.core.context import get_action_request, get_interpreter_id

__all__ = ["init_logging", "set_log_level"]

Expand All @@ -25,6 +25,10 @@ def action_request_processor(_, __, event_dict):
)
except LookupError:
pass
event_dict = dict(
**event_dict,
interpreter_id=get_interpreter_id(),
)
return event_dict


Expand Down
22 changes: 19 additions & 3 deletions flowmachine/flowmachine/core/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
submit_to_executor,
)
from flowmachine.core.errors.flowmachine_errors import QueryResetFailedException
from flowmachine.core.query_manager import get_manager
from flowmachine.core.query_state import QueryStateMachine
from abc import ABCMeta, abstractmethod

Expand Down Expand Up @@ -579,12 +580,27 @@ def to_sql(
unstored_dependencies_graph(self)
) # Need to ensure we're behind our deps in the queue

ddl_ops_func = self._make_sql

logger.debug("Attempting to queue query", query_id=self.query_id)
current_state, changed_to_queue = QueryStateMachine(
get_redis(), self.query_id, get_db().conn_id
).enqueue()
logger.debug(
f"Attempted to enqueue query '{self.query_id}', query state is now {current_state} and change happened {'here and now' if changed_to_queue else 'elsewhere'}."
)
if changed_to_queue:
logger.debug("Queued", query_id=self.query_id)
else:
logger.debug(
"Not queued", query_id=self.query_id, query_state=current_state
)
try:
logger.debug(
"Managed elsewhere.",
manager=get_manager(self.query_id, get_db().conn_id, get_redis()),
)
except AttributeError:
pass # Not being managed

# name, redis, query, connection, ddl_ops_func, write_func, schema = None, sleep_duration = 1
store_future = submit_to_executor(
write_query_to_cache,
name=name,
Expand Down
66 changes: 66 additions & 0 deletions flowmachine/flowmachine/core/query_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import List

import datetime
import structlog
from contextlib import contextmanager
from redis import StrictRedis

from flowmachine.core.context import get_redis, get_interpreter_id

logger = structlog.get_logger("flowmachine.debug", submodule=__name__)


@contextmanager
def managing(query_id: str, db_id: str, redis_connection: StrictRedis):
    """Context manager marking a query as managed by this interpreter.

    Sets the manager records on entry and is guaranteed to remove them on
    exit, whether the managed block completed or raised.

    Parameters
    ----------
    query_id : str
        Identifier of the query being managed.
    db_id : str
        Identifier of the database connection.
    redis_connection : StrictRedis
        Redis connection holding the manager records.
    """
    set_managing(query_id, db_id, redis_connection)
    try:
        yield
    finally:
        # Bug fix: without try/finally an exception inside the managed block
        # would leave the query permanently marked as managed in redis.
        unset_managing(query_id, db_id, redis_connection)

Check warning on line 17 in flowmachine/flowmachine/core/query_manager.py

View check run for this annotation

Codecov / codecov/patch

flowmachine/flowmachine/core/query_manager.py#L15-L17

Added lines #L15 - L17 were not covered by tests


def set_managing(query_id: str, db_id: str, redis_connection: StrictRedis) -> None:
    """Record this interpreter as the manager of a query.

    Writes two redis hash entries: a reverse lookup from the query to the
    managing interpreter, and an entry in this interpreter's own set of
    managed queries stamped with the time management began.

    Parameters
    ----------
    query_id : str
        Identifier of the query being managed.
    db_id : str
        Identifier of the database connection.
    redis_connection : StrictRedis
        Redis connection to write the records to.
    """
    logger.debug("Setting manager.", query_id=query_id, db_id=db_id)
    manager = get_interpreter_id()
    managed_key = f"{db_id}-{query_id}"
    redis_connection.hset("manager", managed_key, manager)
    started_at = datetime.datetime.now().isoformat()
    redis_connection.hset(f"managing:{manager}", managed_key, started_at)
    logger.debug("Set manager.", query_id=query_id, db_id=db_id)


def unset_managing(query_id: str, db_id: str, redis_connection: StrictRedis) -> None:
    """Remove this interpreter's manager records for a query.

    Deletes both the reverse lookup entry and the entry in this
    interpreter's set of managed queries.

    Parameters
    ----------
    query_id : str
        Identifier of the query being released.
    db_id : str
        Identifier of the database connection.
    redis_connection : StrictRedis
        Redis connection holding the records.
    """
    logger.debug("Releasing manager.", query_id=query_id, db_id=db_id)
    managed_key = f"{db_id}-{query_id}"
    redis_connection.hdel("manager", managed_key)
    redis_connection.hdel(f"managing:{get_interpreter_id()}", managed_key)
    logger.debug("Released manager.", query_id=query_id, db_id=db_id)


def get_managed(redis_connection: StrictRedis) -> List[str]:
    """Return the keys of all queries managed by this interpreter.

    Parameters
    ----------
    redis_connection : StrictRedis
        Redis connection holding the manager records.

    Returns
    -------
    List[str]
        Decoded ``"{db_id}-{query_id}"`` keys for every managed query.
    """
    managed = redis_connection.hgetall(f"managing:{get_interpreter_id()}")
    return [managed_key.decode() for managed_key in managed.keys()]


def get_manager(
    query_id: str, connection_id: str, redis_connection: StrictRedis
) -> str:
    """Look up the id of the interpreter currently managing a query.

    Parameters
    ----------
    query_id : str
        Identifier of the query.
    connection_id : str
        Identifier of the database connection.
    redis_connection : StrictRedis
        Redis connection holding the manager records.

    Returns
    -------
    str
        Interpreter id of the manager.

    Raises
    ------
    AttributeError
        If no interpreter manages this query (``hget`` returns None and
        ``.decode()`` fails) — callers such as ``to_sql`` catch this.
    """
    # Bug fix: the decoded value was computed but never returned, so the
    # function always yielded None despite its ``-> str`` annotation.
    return redis_connection.hget("manager", f"{connection_id}-{query_id}").decode()


def release_managed(redis_connection) -> None:
    """Release every query currently managed by this interpreter.

    For each managed query the state machine is driven to a clean state
    (error -> cancel -> reset -> finish resetting) and the manager records
    are removed. Intended to run at interpreter exit or on SIGTERM/SIGINT.

    Parameters
    ----------
    redis_connection : StrictRedis
        Redis connection holding the manager records.
    """
    # Imported here to avoid a circular import with query_state.
    from flowmachine.core.query_state import QueryStateMachine

    logger.error("Releasing managed queries.")
    for managed_key in get_managed(redis_connection):
        # Keys are "{db_id}-{query_id}". Split only on the first dash so a
        # dash inside the query id cannot break the unpacking (the plain
        # split("-") raised ValueError in that case). Assumes db ids contain
        # no dash — TODO confirm.
        conn_id, query_id = managed_key.split("-", 1)
        qsm = QueryStateMachine(
            db_id=conn_id,
            query_id=query_id,
            redis_client=redis_connection,
        )
        qsm.raise_error()
        qsm.cancel()
        qsm.reset()
        qsm.finish_resetting()
        unset_managing(query_id, conn_id, redis_connection)
23 changes: 19 additions & 4 deletions flowmachine/flowmachine/core/query_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from redis import StrictRedis

from flowmachine.core.query_manager import set_managing, unset_managing
from flowmachine.utils import _sleep

logger = logging.getLogger("flowmachine").getChild(__name__)
Expand Down Expand Up @@ -103,7 +104,9 @@ class QueryStateMachine:

def __init__(self, redis_client: StrictRedis, query_id: str, db_id: str):
self.query_id = query_id
self.db_id = db_id
must_populate = redis_client.get(f"finist:{db_id}:{query_id}-state") is None
self.redis_client = redis_client
self.state_machine = Finist(
redis_client, f"{db_id}:{query_id}-state", QueryState.KNOWN
)
Expand Down Expand Up @@ -270,7 +273,10 @@ def cancel(self):
triggered the query to be cancelled with this call

"""
return self.trigger_event(QueryEvent.CANCEL)
state, changed = self.trigger_event(QueryEvent.CANCEL)
if changed:
unset_managing(self.query_id, self.db_id, self.redis_client)
return state, changed

def enqueue(self):
"""
Expand All @@ -283,7 +289,10 @@ def enqueue(self):
triggered the query to be queued with this call

"""
return self.trigger_event(QueryEvent.QUEUE)
state, changed = self.trigger_event(QueryEvent.QUEUE)
if changed:
set_managing(self.query_id, self.db_id, self.redis_client)
return state, changed

def raise_error(self):
"""
Expand All @@ -296,7 +305,10 @@ def raise_error(self):
marked the query as erroring with this call

"""
return self.trigger_event(QueryEvent.ERROR)
state, changed = self.trigger_event(QueryEvent.ERROR)
if changed:
unset_managing(self.query_id, self.db_id, self.redis_client)
return state, changed

def execute(self):
"""
Expand All @@ -322,7 +334,10 @@ def finish(self):
marked the query as finished with this call

"""
return self.trigger_event(QueryEvent.FINISH)
state, changed = self.trigger_event(QueryEvent.FINISH)
if changed:
unset_managing(self.query_id, self.db_id, self.redis_client)
return state, changed

def reset(self):
"""
Expand Down
20 changes: 15 additions & 5 deletions flowmachine/flowmachine/core/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@
import flowmachine
from flowmachine.core import Query, Connection
from flowmachine.core.cache import watch_and_shrink_cache
from flowmachine.core.context import get_db, get_executor, action_request_context
from flowmachine.core.context import (
get_db,
get_executor,
action_request_context,
get_redis,
)
from flowmachine.core.query_manager import get_managed, release_managed
from flowmachine.core.query_state import QueryStateMachine
from flowmachine.utils import convert_dict_keys_to_strings
from .exceptions import FlowmachineServerError
from .zmq_helpers import ZMQReply
Expand Down Expand Up @@ -266,10 +273,13 @@ def main():
logger.info("Enabling asyncio's debugging mode.")

# Run receive loop which receives zmq messages and sends back replies
asyncio.run(
recv(config=config),
debug=config.debug_mode,
) # note: asyncio.run() requires Python 3.7+
try:
asyncio.run(
recv(config=config),
debug=config.debug_mode,
) # note: asyncio.run() requires Python 3.7+
finally:
release_managed(get_redis())


if __name__ == "__main__":
Expand Down
9 changes: 9 additions & 0 deletions flowmachine/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,15 @@ def hset(self, key, current, next):
except KeyError:
self._store[key] = {current.encode(): next.encode()}

def hgetall(self, key):
    """Return all field/value pairs of the hash stored at ``key``.

    Bug fix: real redis returns an empty mapping for a missing key rather
    than raising, so the fake mirrors that instead of raising KeyError.
    """
    return self._store.get(key, {})

def hdel(self, key, name):
    """Delete field ``name`` from the hash stored at ``key``.

    Bug fix: real redis returns the number of fields removed; the fake
    previously returned None on success and 0 only on failure.
    """
    try:
        del self._store[key][name.encode()]
        return 1
    except KeyError:
        return 0

def set(self, key, value):
    # Mimic redis SET: values are stored encoded to bytes, matching what a
    # real redis client would hand back on read.
    self._store[key] = value.encode()

Expand Down
11 changes: 9 additions & 2 deletions flowmachine/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import importlib

import pytest
from flowmachine.core.errors import NotConnectedError
from flowmachine.core.context import get_interpreter_id, get_executor


@pytest.fixture(autouse=True)
Expand All @@ -28,6 +28,13 @@ def flowmachine_connect(): # Override the autoused fixture from the parent
pass


def test_consistent_interpreter_id():
    """The interpreter id must be the same on the main thread and on executor threads."""
    import flowmachine

    flowmachine.connect()
    worker_side_id = get_executor().submit(get_interpreter_id).result()
    assert worker_side_id == get_interpreter_id()


def test_context_manager():
import flowmachine

Expand Down Expand Up @@ -81,7 +88,7 @@ def test_notebook_detection_without_ipython_shell(monkeypatch):
importlib.reload(flowmachine.core.context)

assert not flowmachine.core.context._is_notebook
assert len(flowmachine.core.context._jupyter_context) == 0
assert len(flowmachine.core.context._jupyter_context) == 1


def test_notebook_workaround(monkeypatch):
Expand Down
Loading