From 583c2c151010550173147efdbde610436a850e58 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Fri, 5 Jul 2024 19:07:21 +0200 Subject: [PATCH 1/4] Split off the Qt parts of the UpdateAgent So that we can use it in the API, where we don't want to depend on Qt. Also removed the `description` argument from `variable_set()` because it's currently unused. --- damnit/gui/kafka.py | 79 --------------------------------------- damnit/gui/main_window.py | 71 ++++++++++++++++++++++++++--------- damnit/kafka.py | 50 +++++++++++++++++++++++++ tests/test_gui.py | 11 +++--- 4 files changed, 110 insertions(+), 101 deletions(-) delete mode 100644 damnit/gui/kafka.py create mode 100644 damnit/kafka.py diff --git a/damnit/gui/kafka.py b/damnit/gui/kafka.py deleted file mode 100644 index 010425c6..00000000 --- a/damnit/gui/kafka.py +++ /dev/null @@ -1,79 +0,0 @@ -import pickle -import logging - -from kafka import KafkaConsumer, KafkaProducer -from PyQt5 import QtCore - -from ..backend.db import MsgKind, msg_dict -from ..definitions import UPDATE_BROKERS, UPDATE_TOPIC - -log = logging.getLogger(__name__) - - -class UpdateAgent(QtCore.QObject): - message = QtCore.pyqtSignal(object) - - def __init__(self, db_id: str) -> None: - QtCore.QObject.__init__(self) - self.update_topic = UPDATE_TOPIC.format(db_id) - - self.kafka_cns = KafkaConsumer( - self.update_topic, bootstrap_servers=UPDATE_BROKERS - ) - self.kafka_prd = KafkaProducer(bootstrap_servers=UPDATE_BROKERS, - value_serializer=lambda d: pickle.dumps(d)) - self.running = False - - def listen_loop(self) -> None: - self.running = True - - while self.running: - # Note: this doesn't throw an exception on timeout, it just returns - # an empty dict. 
- topic_messages = self.kafka_cns.poll(timeout_ms=100) - - for topic, messages in topic_messages.items(): - for msg in messages: - try: - unpickled_msg = pickle.loads(msg.value) - except Exception: - log.error("Kafka event could not be un-pickled.", exc_info=True) - continue - - self.message.emit(unpickled_msg) - - def run_values_updated(self, proposal, run, name, value): - message = msg_dict(MsgKind.run_values_updated, - { - "proposal": proposal, - "run": run, - "values": { - name: value - } - }) - - # Note: the send() function returns a future that we don't await - # immediately, but we call kafka_prd.flush() in stop() which will ensure - # that all messages are sent. - self.kafka_prd.send(self.update_topic, message) - - def variable_set(self, name, title, description, variable_type): - message = msg_dict(MsgKind.variable_set, - { - "name": name, - "title": title, - "attributes": None, - "type": variable_type - }) - self.kafka_prd.send(self.update_topic, message) - - def stop(self): - self.running = False - self.kafka_prd.flush(timeout=10) - - -if __name__ == "__main__": - monitor = UpdateAgent("tcp://localhost:5556") - - for record in monitor.kafka_cns: - print(record.value.decode()) diff --git a/damnit/gui/main_window.py b/damnit/gui/main_window.py index 2a97c039..dfd39b3f 100644 --- a/damnit/gui/main_window.py +++ b/damnit/gui/main_window.py @@ -12,6 +12,7 @@ import numpy as np import pandas as pd import xarray as xr +from kafka import KafkaConsumer from kafka.errors import NoBrokersAvailable from pandas.api.types import infer_dtype from plotly.graph_objects import Figure as PlotlyFigure @@ -31,7 +32,7 @@ from ..definitions import UPDATE_BROKERS from ..util import StatusbarStylesheet, fix_data_for_plotting, icon_path from .editor import ContextTestResult, Editor -from .kafka import UpdateAgent +from ..kafka import UpdateProducer from .open_dialog import OpenDBDialog from .new_context_dialog import NewContextFileDialog from .plot import Canvas, Plot @@ -87,7 
+88,7 @@ def __init__(self, context_dir: Path = None, connect_to_kafka: bool = True): self.setCentralWidget(self._tab_widget) self.table = None - + self.zulip_messenger = None self._create_view() @@ -124,11 +125,15 @@ def closeEvent(self, event): def stop_update_listener_thread(self): if self._updates_thread is not None: - self.update_agent.stop() + self.update_consumer.stop() self._updates_thread.exit() self._updates_thread.wait() self._updates_thread = None + # We call flush() here to ensure that all messages are sent since we + # don't pass `flush=True` anywhere else. + self.update_producer.kafka_prd.flush(timeout=10) + def center_window(self): """ Center and resize the window to the screen the cursor is placed on. @@ -288,7 +293,7 @@ def add_variable(self, name, title, variable_type, description="", before=None): self.table.add_editable_column(name) if self._connect_to_kafka: - self.update_agent.variable_set(name, title, description, variable_type) + self.update_producer.variable_set(name, title, variable_type) def open_column_dialog(self): if self._columns_dialog is None: @@ -373,12 +378,12 @@ def _create_menu_bar(self) -> None: action_precreate_runs = QtWidgets.QAction("Pre-create new runs", self) action_precreate_runs.triggered.connect(self.precreate_runs_dialog) tableMenu = menu_bar.addMenu("Table") - + tableMenu.addAction(action_columns) tableMenu.addAction(self.action_autoscroll) tableMenu.addAction(action_precreate_runs) - - #jump to run + + #jump to run menu_bar_right = QtWidgets.QMenuBar(self) searchMenu = menu_bar_right.addMenu( QtGui.QIcon(icon_path("search_icon.png")), "&Search Run") @@ -393,7 +398,7 @@ def _create_menu_bar(self) -> None: searchMenu.addAction(actionWidget) menu_bar.setCornerWidget(menu_bar_right, Qt.TopRightCorner) - + def scroll_to_run(self, run): try: run = int(run) @@ -502,7 +507,8 @@ def _updates_thread_launcher(self) -> None: assert self.db_id is not None try: - self.update_agent = UpdateAgent(self.db_id) + self.update_consumer 
= UpdateConsumer(self.db.kafka_topic) + self.update_producer = UpdateProducer(self.db.kafka_topic) except NoBrokersAvailable: QtWidgets.QMessageBox.warning(self, "Broker connection failed", f"Could not connect to any Kafka brokers at: {' '.join(UPDATE_BROKERS)}\n\n" + @@ -510,10 +516,10 @@ def _updates_thread_launcher(self) -> None: return self._updates_thread = QtCore.QThread() - self.update_agent.moveToThread(self._updates_thread) + self.update_consumer.moveToThread(self._updates_thread) - self._updates_thread.started.connect(self.update_agent.listen_loop) - self.update_agent.message.connect(self.handle_update) + self._updates_thread.started.connect(self.update_consumer.listen_loop) + self.update_consumer.message.connect(self.handle_update) QtCore.QTimer.singleShot(0, self._updates_thread.start) def _set_comment_date(self): @@ -725,7 +731,7 @@ def _create_view(self) -> None: plot_vertical_layout.addLayout(plot_parameters_horizontal_layout) plotting_group.setLayout(plot_vertical_layout) - + collapsible.add_widget(plotting_group) vertical_layout.setSpacing(0) @@ -853,7 +859,7 @@ def save_value(self, prop, run, name, value): log.debug("Saving data for variable %s for prop %d run %d", name, prop, run) self.db.set_variable(prop, run, name, ReducedData(value)) if self._connect_to_kafka: - self.update_agent.run_values_updated(prop, run, name, value) + self.update_producer.run_values_updated(prop, run, name, value) def save_time_comment(self, comment_id, value): if self.db is None: @@ -862,11 +868,11 @@ def save_time_comment(self, comment_id, value): log.debug("Saving time-based comment ID %d", comment_id) self.db.change_standalone_comment(comment_id, value) - + def check_zulip_messenger(self): if not isinstance(self.zulip_messenger, ZulipMessenger): - self.zulip_messenger = ZulipMessenger(self) - + self.zulip_messenger = ZulipMessenger(self) + if not self.zulip_messenger.ok: self.zulip_messenger = None return False @@ -944,6 +950,37 @@ def __init__(self, file_path: 
Path, parent=None): self.setCentralWidget(self.text_edit) self.resize(1000, 800) +class UpdateConsumer(QtCore.QObject): + message = QtCore.pyqtSignal(object) + + def __init__(self, topic: str) -> None: + QtCore.QObject.__init__(self) + + self.kafka_cns = KafkaConsumer( + topic, bootstrap_servers=UPDATE_BROKERS + ) + self.running = False + + def listen_loop(self) -> None: + self.running = True + + while self.running: + # Note: this doesn't throw an exception on timeout, it just returns + # an empty dict. + topic_messages = self.kafka_cns.poll(timeout_ms=100) + + for topic, messages in topic_messages.items(): + for msg in messages: + try: + unpickled_msg = pickle.loads(msg.value) + except Exception: + log.error("Kafka event could not be un-pickled.", exc_info=True) + continue + + self.message.emit(unpickled_msg) + + def stop(self): + self.running = False def prompt_setup_db_and_backend(context_dir: Path, prop_no=None, parent=None): if not db_path(context_dir).is_file(): diff --git a/damnit/kafka.py b/damnit/kafka.py new file mode 100644 index 00000000..9fd50093 --- /dev/null +++ b/damnit/kafka.py @@ -0,0 +1,50 @@ +import pickle +import logging + +from kafka import KafkaProducer + +from .backend.db import MsgKind, msg_dict +from .definitions import UPDATE_BROKERS + +log = logging.getLogger(__name__) + + +class UpdateProducer: + def __init__(self, default_topic) -> None: + self.default_topic = default_topic + + self.kafka_prd = KafkaProducer(bootstrap_servers=UPDATE_BROKERS, + value_serializer=lambda d: pickle.dumps(d)) + + def run_values_updated(self, proposal, run, name, value, flush=False, topic=None): + if topic is None: + topic = self.default_topic + + message = msg_dict(MsgKind.run_values_updated, + { + "proposal": proposal, + "run": run, + "values": { + name: value + } + }) + self.kafka_prd.send(topic, message) + + if flush: + self.kafka_prd.flush() + + def variable_set(self, name, title, variable_type, flush=False, topic=None): + if topic is None: + topic = 
self.default_topic + + message = msg_dict(MsgKind.variable_set, + { + "name": name, + "title": title, + "attributes": None, + "type": variable_type + }) + self.kafka_prd.send(topic, message) + + if flush: + self.kafka_prd.flush() diff --git a/tests/test_gui.py b/tests/test_gui.py index 242317a9..b02f2d33 100644 --- a/tests/test_gui.py +++ b/tests/test_gui.py @@ -37,16 +37,17 @@ def pid_dead(pid): def test_connect_to_kafka(mock_db, qtbot): db_dir, db = mock_db - pkg = "damnit.gui.kafka" + consumer_import = "damnit.gui.main_window.KafkaConsumer" + producer_import = "damnit.kafka.KafkaProducer" - with patch(f"{pkg}.KafkaConsumer") as kafka_cns, \ - patch(f"{pkg}.KafkaProducer") as kafka_prd: + with patch(consumer_import) as kafka_cns, \ + patch(producer_import) as kafka_prd: MainWindow(db_dir, False).close() kafka_cns.assert_not_called() kafka_prd.assert_not_called() - with patch(f"{pkg}.KafkaConsumer") as kafka_cns, \ - patch(f"{pkg}.KafkaProducer") as kafka_prd: + with patch(consumer_import) as kafka_cns, \ + patch(producer_import) as kafka_prd: MainWindow(db_dir, True).close() kafka_cns.assert_called_once() kafka_prd.assert_called_once() From c1c2fccb48e542c13f986388d3b2daaa59d08648 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Fri, 5 Jul 2024 19:08:34 +0200 Subject: [PATCH 2/4] Remove log statement from DamnitDB.get_user_variables() --- damnit/backend/db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/damnit/backend/db.py b/damnit/backend/db.py index 06176851..bf7b58ff 100644 --- a/damnit/backend/db.py +++ b/damnit/backend/db.py @@ -209,7 +209,7 @@ def get_user_variables(self): attributes=rr["attributes"], ) user_variables[var_name] = new_var - log.debug("Loaded %d user variables", len(user_variables)) + return user_variables def update_computed_variables(self, vars: dict): From 30b27ce5f13d9d5a9edaa80fbe66599a24c15980 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Fri, 5 Jul 2024 19:09:18 +0200 Subject: [PATCH 3/4] Support 
validating/converting Python objects for editable variables --- damnit/backend/user_variables.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/damnit/backend/user_variables.py b/damnit/backend/user_variables.py index cabf1b73..846baed2 100644 --- a/damnit/backend/user_variables.py +++ b/damnit/backend/user_variables.py @@ -10,17 +10,23 @@ class ValueType: examples = None + py_type = None + def __str__(self): return self.type_name @classmethod def parse(cls, input: str): - return input + return cls.py_type(input) @classmethod def from_db_value(cls, value): return value + @classmethod + def to_db_value(cls, value): + return cls.py_type(value) + class BooleanValueType(ValueType): type_name = "boolean" @@ -29,6 +35,8 @@ class BooleanValueType(ValueType): examples = ["True", "T", "true", "1", "False", "F", "f", "0"] + py_type = bool + _valid_values = { "true": True, "yes": True, @@ -61,6 +69,10 @@ def from_db_value(cls, value): return None return bool(value) + @classmethod + def to_db_value(cls, value): + return value if isinstance(value, bool) else None + class IntegerValueType(ValueType): type_name = "integer" @@ -69,9 +81,7 @@ class IntegerValueType(ValueType): examples = ["-7", "-2", "0", "10", "34"] - @classmethod - def parse(cls, input: str): - return int(input) + py_type = int class NumberValueType(ValueType): type_name = "number" @@ -80,10 +90,7 @@ class NumberValueType(ValueType): examples = ["-34.1e10", "-7.1", "-4", "0.0", "3.141592653589793", "85.4E7"] - @classmethod - def parse(cls, input: str): - return float(input) - + py_type = float class StringValueType(ValueType): type_name = "string" @@ -92,6 +99,12 @@ class StringValueType(ValueType): examples = ["Broken", "Dark frame", "test_frame"] + py_type = str + + @classmethod + def to_db_value(cls, value): + return value if isinstance(value, str) else None + value_types_by_name = {tt.type_name: tt for tt in [ BooleanValueType(), IntegerValueType(), 
NumberValueType(), StringValueType() From a8fdf04a9e1408dc37c25629b3425e3945f61e57 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Fri, 5 Jul 2024 19:16:36 +0200 Subject: [PATCH 4/4] Add support for writing to user-editable variables through the API --- damnit/api.py | 77 ++++++++++++++++++++++++++++++++++++++++++----- docs/api.md | 9 ++++++ pyproject.toml | 2 +- tests/test_api.py | 63 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 142 insertions(+), 9 deletions(-) diff --git a/damnit/api.py b/damnit/api.py index ab3ecd6e..de5a82f8 100644 --- a/damnit/api.py +++ b/damnit/api.py @@ -10,7 +10,8 @@ import plotly.io as pio import xarray as xr -from .backend.db import BlobTypes, DamnitDB +from .kafka import UpdateProducer +from .backend.db import BlobTypes, ReducedData, DamnitDB # This is a copy of damnit.ctxsupport.ctxrunner.DataType, purely so that we can @@ -38,6 +39,13 @@ def find_proposal(propno): raise FileNotFoundError("Couldn't find proposal dir for {!r}".format(propno)) +# This variable is meant to be an instance of UpdateProducer, but we lazily +# initialize it because creating the producer takes ~100ms. Which isn't +# very much, but it may otherwise be created hundreds of times if used in a +# context file so it's better to avoid it where possible. +UPDATE_PRODUCER = None + + class VariableData: """Represents a variable for a single run. 
@@ -48,7 +56,7 @@ class VariableData: def __init__(self, name: str, title: str, proposal: int, run: int, h5_path: Path, data_format_version: int, - db: DamnitDB, db_only: bool): + db: DamnitDB, db_only: bool, missing: bool): self._name = name self._title = title self._proposal = proposal @@ -57,6 +65,7 @@ def __init__(self, name: str, title: str, self._data_format_version = data_format_version self._db = db self._db_only = db_only + self._missing = missing @property def name(self) -> str: @@ -145,6 +154,39 @@ def read(self): # Otherwise, return a Numpy array return group["data"][()] + def write(self, value, send_update=True): + """Write a value to a user-editable variable. + + This may throw an exception if converting `value` to the type of the + editable variable fails, e.g. `db[100]["number"] = "foo"` will fail if + the `number` variable has a numeric type. + + Args: + send_update (bool): Whether or not to send an update after + writing to the database. Don't use this unless you know what + you're doing, it may disappear in the future. + """ + if not self._db_only: + raise RuntimeError(f"Cannot write to variable '{self.name}', it's not a user-editable variable.") + + # Convert the input + user_variable = self._db.get_user_variables()[self.name] + variable_type = user_variable.get_type_class() + value = variable_type.to_db_value(value) + if value is None: + raise ValueError(f"Forbidden conversion of value '{value!r}' to type '{variable_type.py_type}'") + + # Write to the database + self._db.set_variable(self.proposal, self.run, self.name, ReducedData(value)) + + if send_update: + global UPDATE_PRODUCER + if UPDATE_PRODUCER is None: + UPDATE_PRODUCER = UpdateProducer(None) + + UPDATE_PRODUCER.variable_set(self.name, self.title, variable_type.type_name, + flush=True, topic=self._db.kafka_topic) + def summary(self): """Read the summary data for a variable. 
@@ -204,21 +246,40 @@ def file(self) -> Path: """The path to the HDF5 file for the run.""" return self._h5_path - def __getitem__(self, name): + def _get_variable(self, name): key_locs = self._key_locations() names_to_titles = self._var_titles() titles_to_names = { title: name for name, title in names_to_titles.items() } - if name not in key_locs and name not in titles_to_names: - raise KeyError(f"Variable data for '{name!r}' not found for p{self.proposal}, r{self.run}") - if name in titles_to_names: name = titles_to_names[name] - return VariableData(name, names_to_titles[name], + missing = name not in key_locs + user_variables = self._db.get_user_variables() + if missing and name in user_variables: + key_locs[name] = True + elif missing and name not in user_variables: + raise KeyError(f"Variable data for '{name!r}' not found for p{self.proposal}, r{self.run}") + + return VariableData(name, names_to_titles.get(name), self.proposal, self.run, self._h5_path, self._data_format_version, - self._db, key_locs[name]) + self._db, key_locs[name], + missing) + + def __getitem__(self, name): + variable = self._get_variable(name) + if variable._missing: + raise KeyError(f"Variable data for '{name!r}' not found for p{self.proposal}, r{self.run}") + + return variable + + def __setitem__(self, name, value): + variable = self._get_variable(name) + + # The environment variable is basically only useful for tests + send_update = bool(int(os.environ.get("DAMNIT_SEND_UPDATE", 1))) + variable.write(value, send_update) def _key_locations(self): # Read keys from the HDF5 file diff --git a/docs/api.md b/docs/api.md index 64df3059..f880bd37 100644 --- a/docs/api.md +++ b/docs/api.md @@ -31,6 +31,15 @@ data = myvar.read() summary = myvar.summary() ``` +You can also write to [user-editable +variables](gui.md#adding-user-editable-variables): +```python +run_vars["myvar"] = 42 + +# An alternative style would be: +myvar.write(42) +``` + ## API reference ::: damnit.Damnit diff --git 
a/pyproject.toml b/pyproject.toml index 42455403..740bac30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ readme = "README.md" dependencies = [ "h5netcdf", "h5py", + "kafka-python-ng", "orjson", # used in plotly for faster json serialization "pandas", "plotly", @@ -27,7 +28,6 @@ dependencies = [ backend = [ "EXtra-data", "ipython", - "kafka-python-ng", "kaleido", # used in plotly to convert figures to images "matplotlib", "numpy", diff --git a/tests/test_api.py b/tests/test_api.py index dfb8becb..bb72ff25 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,6 +1,8 @@ +import os import subprocess from pathlib import Path from textwrap import dedent +from unittest.mock import patch import numpy as np import plotly.express as px @@ -10,6 +12,7 @@ from damnit import Damnit, RunVariables from damnit.context import ContextFile +from damnit.backend.user_variables import UserEditableVariable from .helpers import extract_mock_run @@ -90,6 +93,10 @@ def test_variable_data(mock_db_with_data, monkeypatch): damnit = Damnit(db_dir) rv = damnit[1] + # We disable updates from VariableData.write() by default so we can use the + # convenient RunVariables.__setitem__() method. + monkeypatch.setenv("DAMNIT_SEND_UPDATE", "0") + # Insert a DataSet variable dataset_code = """ from damnit_ctx import Variable @@ -130,6 +137,62 @@ def dataset(run): assert isinstance(fig, PlotlyFigure) assert fig == px.bar(x=["a", "b", "c"], y=[1, 3, 2]) + # It shouldn't be possible to write to non-editable variables or the default + # variables from DAMNIT. + with pytest.raises(RuntimeError): + rv["dataset"] = 1 + with pytest.raises(RuntimeError): + rv["start_time"] = 1 + + # It also shouldn't be possible to write to variables that don't exist. We + # have to test this because RunVariables.__setitem__() will allow creating + # VariableData objects for variables that are missing from the run but do + # exist in the database. 
+ with pytest.raises(KeyError): + rv["blah"] = 1 + + # Test setting an editable value + db.add_user_variable(UserEditableVariable("foo", "Foo", "number")) + rv["foo"] = 3.14 + + # Test sending Kafka updates + foo_var = rv["foo"] + with patch("damnit.kafka.KafkaProducer") as kafka_prd: + foo_var.write(42, send_update=True) + kafka_prd.assert_called_once() + +# These are smoke tests to ensure that writing of all variable types succeeds; +# tests of special cases are above in test_variable_data. +# +# Note that we allow any value to be converted to strings for convenience, so +# its `bad_input` value is None. +@pytest.mark.parametrize("variable_name,good_input,bad_input", + [("boolean", True, "foo"), + ("integer", 42, "foo"), + ("number", 3.14, "foo"), + ("stringy", "foo", None)]) +def test_writing(variable_name, good_input, bad_input, mock_db_with_data, monkeypatch): + db_dir, db = mock_db_with_data + monkeypatch.chdir(db_dir) + monkeypatch.setenv("DAMNIT_SEND_UPDATE", "0") + damnit = Damnit(db_dir) + rv = damnit[1] + + # There's already a `string` variable in the test context file so we can't + # reuse the name as the type for our editable variable. + variable_type = "string" if variable_name == "stringy" else variable_name + + # Add the user-editable variable + user_var = UserEditableVariable(variable_name, variable_name.capitalize(), variable_type) + db.add_user_variable(user_var) + + rv[variable_name] = good_input + assert rv[variable_name].read() == good_input + + if bad_input is not None: + with pytest.raises(ValueError): + rv[variable_name] = bad_input + def test_api_dependencies(venv): package_path = Path(__file__).parent.parent venv.install(package_path)