From 24649428e8de29a47b4ccb58b4a1cda2b641611d Mon Sep 17 00:00:00 2001 From: Fred Moolekamp Date: Thu, 21 Sep 2023 17:48:01 -0400 Subject: [PATCH 1/3] Initial code and docs commit --- .github/pull_request_template.md | 4 + .github/workflows/build.yaml | 55 ++++ .github/workflows/build_docs.yaml | 41 +++ .github/workflows/formatting.yaml | 11 + .github/workflows/lint.yaml | 11 + .github/workflows/yamllint.yaml | 11 + .pre-commit-config.yaml | 32 ++ doc/.gitignore | 10 + doc/SConscript | 3 + doc/conf.py | 12 + doc/doxygen.conf.in | 0 doc/index.rst | 12 + doc/lsst.rubintv.analysis.service/index.rst | 40 +++ doc/manifest.yaml | 12 + mypy.ini | 22 ++ pyproject.toml | 13 +- python/lsst/__init__.py | 2 +- python/lsst/rubintv/__init__.py | 2 +- python/lsst/rubintv/analysis/__init__.py | 2 +- .../lsst/rubintv/analysis/service/__init__.py | 1 + .../lsst/rubintv/analysis/service/butler.py | 29 ++ .../lsst/rubintv/analysis/service/client.py | 81 +++++ .../lsst/rubintv/analysis/service/command.py | 255 ++++++++++++++++ .../lsst/rubintv/analysis/service/database.py | 271 +++++++++++++++++ python/lsst/rubintv/analysis/service/query.py | 151 ++++++++++ requirements.txt | 12 + scripts/config.yaml | 8 + scripts/mock_server.py | 259 ++++++++++++++++ scripts/rubintv_worker.py | 54 ++++ setup.cfg | 2 +- tests/schema.yaml | 42 +++ tests/test_command.py | 210 +++++++++++++ tests/test_database.py | 103 +++++++ tests/test_query.py | 280 ++++++++++++++++++ tests/utils.py | 197 ++++++++++++ 35 files changed, 2241 insertions(+), 9 deletions(-) create mode 100644 .github/pull_request_template.md create mode 100644 .github/workflows/build.yaml create mode 100644 .github/workflows/build_docs.yaml create mode 100644 .github/workflows/formatting.yaml create mode 100644 .github/workflows/lint.yaml create mode 100644 .github/workflows/yamllint.yaml create mode 100644 .pre-commit-config.yaml create mode 100644 doc/.gitignore create mode 100644 doc/SConscript create mode 100644 doc/conf.py create mode 
100644 doc/doxygen.conf.in create mode 100644 doc/index.rst create mode 100644 doc/lsst.rubintv.analysis.service/index.rst create mode 100644 doc/manifest.yaml create mode 100644 mypy.ini create mode 100644 python/lsst/rubintv/analysis/service/__init__.py create mode 100644 python/lsst/rubintv/analysis/service/butler.py create mode 100644 python/lsst/rubintv/analysis/service/client.py create mode 100644 python/lsst/rubintv/analysis/service/command.py create mode 100644 python/lsst/rubintv/analysis/service/database.py create mode 100644 python/lsst/rubintv/analysis/service/query.py create mode 100644 requirements.txt create mode 100644 scripts/config.yaml create mode 100644 scripts/mock_server.py create mode 100644 scripts/rubintv_worker.py create mode 100644 tests/schema.yaml create mode 100644 tests/test_command.py create mode 100644 tests/test_database.py create mode 100644 tests/test_query.py create mode 100644 tests/utils.py diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..7b345f4 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,4 @@ +## Checklist + +- [ ] ran Jenkins +- [ ] added a release note for user-visible changes to `doc/changes` diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 0000000..0055df1 --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,55 @@ +name: build_and_test + +on: + push: + branches: + - main + tags: + - "*" + pull_request: + +jobs: + build_and_test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + # Need to clone everything for the git tags. 
+ fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + cache: "pip" + cache-dependency-path: "setup.cfg" + + - name: Install yaml + run: sudo apt-get install libyaml-dev + + - name: Install prereqs for setuptools + run: pip install wheel + + # We have two cores so we can speed up the testing with xdist + - name: Install xdist, openfiles and flake8 for pytest + run: > + pip install pytest-xdist pytest-openfiles pytest-flake8 + pytest-cov "flake8<5" + + - name: Build and install + run: pip install -v -e . + + - name: Install documenteer + run: pip install 'documenteer[pipelines]<0.7' + + - name: Run tests + run: > + pytest -r a -v -n 3 --open-files --cov=tests + --cov=lsst.rubintv.analysis.service + --cov-report=xml --cov-report=term + --doctest-modules --doctest-glob="*.rst" + + - name: Upload coverage to codecov + uses: codecov/codecov-action@v2 + with: + file: ./coverage.xml diff --git a/.github/workflows/build_docs.yaml b/.github/workflows/build_docs.yaml new file mode 100644 index 0000000..75d1dac --- /dev/null +++ b/.github/workflows/build_docs.yaml @@ -0,0 +1,41 @@ +name: docs + +on: + push: + branches: + - main + pull_request: + +jobs: + build_sphinx_docs: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + # Need to clone everything for the git tags. + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + cache: "pip" + cache-dependency-path: "setup.cfg" + + - name: Update pip/wheel infrastructure + run: | + python -m pip install --upgrade pip + pip install wheel + + - name: Build and install + run: pip install -v -e . 
+ + - name: Show compiled files + run: ls python/lsst/rubintv/analysis/service + + - name: Install documenteer + run: pip install 'documenteer[pipelines]<0.7' + + - name: Build documentation + working-directory: ./doc + run: package-docs build diff --git a/.github/workflows/formatting.yaml b/.github/workflows/formatting.yaml new file mode 100644 index 0000000..27f34a6 --- /dev/null +++ b/.github/workflows/formatting.yaml @@ -0,0 +1,11 @@ +name: Check Python formatting + +on: + push: + branches: + - main + pull_request: + +jobs: + call-workflow: + uses: lsst/rubin_workflows/.github/workflows/formatting.yaml@main diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 0000000..796ef92 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,11 @@ +name: lint + +on: + push: + branches: + - main + pull_request: + +jobs: + call-workflow: + uses: lsst/rubin_workflows/.github/workflows/lint.yaml@main diff --git a/.github/workflows/yamllint.yaml b/.github/workflows/yamllint.yaml new file mode 100644 index 0000000..76ad875 --- /dev/null +++ b/.github/workflows/yamllint.yaml @@ -0,0 +1,11 @@ +name: Lint YAML Files + +on: + push: + branches: + - main + pull_request: + +jobs: + call-workflow: + uses: lsst/rubin_workflows/.github/workflows/yamllint.yaml@main diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a07ff8a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,32 @@ +repos: + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-yaml + args: + - "--unsafe" + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + # It is recommended to specify the latest version of Python + + # supported by your project here, or alternatively use + + # pre-commit's default_language_version, see + + # https://pre-commit.com/#top_level-default_language_version + + language_version: 
python3.10 + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: isort (python) + - repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 diff --git a/doc/.gitignore b/doc/.gitignore new file mode 100644 index 0000000..ad2c2bb --- /dev/null +++ b/doc/.gitignore @@ -0,0 +1,10 @@ +# Doxygen products +html +xml +*.tag +*.inc +doxygen.conf + +# Sphinx products +_build +py-api diff --git a/doc/SConscript b/doc/SConscript new file mode 100644 index 0000000..61b554a --- /dev/null +++ b/doc/SConscript @@ -0,0 +1,3 @@ +# -*- python -*- +from lsst.sconsUtils import scripts +scripts.BasicSConscript.doc() diff --git a/doc/conf.py b/doc/conf.py new file mode 100644 index 0000000..0dcae36 --- /dev/null +++ b/doc/conf.py @@ -0,0 +1,12 @@ +"""Sphinx configuration file for an LSST stack package. +This configuration only affects single-package Sphinx documentation builds. +For more information, see: +https://developer.lsst.io/stack/building-single-package-docs.html +""" + +from documenteer.conf.pipelinespkg import * + +project = "rubintv_analysis_service" +html_theme_options["logotext"] = project +html_title = project +html_short_title = project diff --git a/doc/doxygen.conf.in b/doc/doxygen.conf.in new file mode 100644 index 0000000..e69de29 diff --git a/doc/index.rst b/doc/index.rst new file mode 100644 index 0000000..f7713f5 --- /dev/null +++ b/doc/index.rst @@ -0,0 +1,12 @@ +############################################## +rubintv_analysis_service documentation preview +############################################## + +.. This page is for local development only. It isn't published to pipelines.lsst.io. + +.. Link the index pages of package and module documentation directions (listed in manifest.yaml). + +.. 
toctree:: + :maxdepth: 1 + + lsst.rubintv.analysis.service/index diff --git a/doc/lsst.rubintv.analysis.service/index.rst b/doc/lsst.rubintv.analysis.service/index.rst new file mode 100644 index 0000000..3707f25 --- /dev/null +++ b/doc/lsst.rubintv.analysis.service/index.rst @@ -0,0 +1,40 @@ +.. py:currentmodule:: lsst.rubintv.analysis.service + +.. _lsst.rubintv.analysis.service: + +############################# +lsst.rubintv.analysis.service +############################# + +.. Paragraph that describes what this Python module does and links to related modules and frameworks. + +.. _lsst.rubintv.analysis.service-using: + +Using lsst.rubintv.analysis.service +======================= + +toctree linking to topics related to using the module's APIs. + +.. toctree:: + :maxdepth: 2 + +.. _lsst.rubintv.analysis.service-contributing: + +Contributing +============ + +``lsst.rubintv.analysis.service`` is developed at https://github.com/lsst-ts/rubintv_analysis_service. + +.. If there are topics related to developing this module (rather than using it), link to this from a toctree placed here. + +.. .. toctree:: +.. :maxdepth: 2 + +.. _lsst.rubintv.analysis.service-pyapi: + +Python API reference +==================== + +.. automodapi:: lsst.rubintv.analysis.service + :no-main-docstr: + :no-inheritance-diagram: diff --git a/doc/manifest.yaml b/doc/manifest.yaml new file mode 100644 index 0000000..222eb97 --- /dev/null +++ b/doc/manifest.yaml @@ -0,0 +1,12 @@ +# Documentation manifest. + +# List of names of Python modules in this package. +# For each module there is a corresponding module doc subdirectory. +modules: + - "lsst.rubintv.analysis.service" + +# Name of the static content directories (subdirectories of `_static`). +# Static content directories are usually named after the package. +# Most packages do not need a static content directory (leave commented out). 
+# statics: +# - "_static/example_standalone" diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..3d32554 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,22 @@ +[mypy] +warn_unused_configs = True +warn_redundant_casts = True +plugins = pydantic.mypy + +[mypy-astropy.*] +ignore_missing_imports = True + +[mypy-matplotlib.*] +ignore_missing_imports = True + +[mypy-numpy.*] +ignore_missing_imports = True + +[mypy-scipy.*] +ignore_missing_imports = True + +[mypy-sqlalchemy.*] +ignore_missing_imports = True + +[mypy-yaml.*] +ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml index 3b951cd..98176c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,12 +30,18 @@ dependencies = [ "scipy", "matplotlib", "pydantic", - "pyyaml" + "pyyaml", + "sqlalchemy", + "astropy", + "websocket-client", + "lsst-daf-butler", + # temporary dependency for testing + "tornado", ] #dynamic = ["version"] [project.urls] -"Homepage" = "https://github.com/lsst/rubintv_analysis_service" +"Homepage" = "https://github.com/lsst-ts/rubintv_analysis_service" [project.optional-dependencies] test = [ @@ -65,9 +71,6 @@ line_length = 110 [tool.lsst_versions] write_to = "python/lsst/rubintv/analysis/service/version.py" -[tool.pytest.ini_options] -addopts = "--flake8" -flake8-ignore = ["W503", "E203"] # The matplotlib test may not release font files. 
open_files_ignore = ["*.ttf"] diff --git a/python/lsst/__init__.py b/python/lsst/__init__.py index eb1f6e6..f77af49 100644 --- a/python/lsst/__init__.py +++ b/python/lsst/__init__.py @@ -1,3 +1,3 @@ import pkgutil -__path__ = pkgutil.extend_path(__path__, __name__) \ No newline at end of file +__path__ = pkgutil.extend_path(__path__, __name__) diff --git a/python/lsst/rubintv/__init__.py b/python/lsst/rubintv/__init__.py index eb1f6e6..f77af49 100644 --- a/python/lsst/rubintv/__init__.py +++ b/python/lsst/rubintv/__init__.py @@ -1,3 +1,3 @@ import pkgutil -__path__ = pkgutil.extend_path(__path__, __name__) \ No newline at end of file +__path__ = pkgutil.extend_path(__path__, __name__) diff --git a/python/lsst/rubintv/analysis/__init__.py b/python/lsst/rubintv/analysis/__init__.py index eb1f6e6..f77af49 100644 --- a/python/lsst/rubintv/analysis/__init__.py +++ b/python/lsst/rubintv/analysis/__init__.py @@ -1,3 +1,3 @@ import pkgutil -__path__ = pkgutil.extend_path(__path__, __name__) \ No newline at end of file +__path__ = pkgutil.extend_path(__path__, __name__) diff --git a/python/lsst/rubintv/analysis/service/__init__.py b/python/lsst/rubintv/analysis/service/__init__.py new file mode 100644 index 0000000..e144f74 --- /dev/null +++ b/python/lsst/rubintv/analysis/service/__init__.py @@ -0,0 +1 @@ +from . import command, database, query diff --git a/python/lsst/rubintv/analysis/service/butler.py b/python/lsst/rubintv/analysis/service/butler.py new file mode 100644 index 0000000..35dd967 --- /dev/null +++ b/python/lsst/rubintv/analysis/service/butler.py @@ -0,0 +1,29 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. 
+# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from dataclasses import dataclass + +from .command import BaseCommand + + +@dataclass +class ExampleButlerCommand(BaseCommand): + pass diff --git a/python/lsst/rubintv/analysis/service/client.py b/python/lsst/rubintv/analysis/service/client.py new file mode 100644 index 0000000..9504b0b --- /dev/null +++ b/python/lsst/rubintv/analysis/service/client.py @@ -0,0 +1,81 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+import logging + +import sqlalchemy +import yaml +from lsst.daf.butler import Butler +from websocket import WebSocketApp + +from .command import DatabaseConnection, execute_command + +logger = logging.getLogger("lsst.rubintv.analysis.service.client") + + +def on_error(ws: WebSocketApp, error: str) -> None: + """Error received from the server.""" + print(f"\033[91mError: {error}\033[0m") + + +def on_close(ws: WebSocketApp, close_status_code: str, close_msg: str) -> None: + """Connection closed by the server.""" + print("\033[93mConnection closed\033[0m") + + +def run_worker(address: str, port: int, connection_info: dict[str, dict]) -> None: + """Run the worker and connect to the rubinTV server. + + Parameters + ---------- + address : + Address of the rubinTV web app. + port : + Port of the rubinTV web app websockets. + connection_info : + Connections . + """ + # Load the database connection information + databases: dict[str, DatabaseConnection] = {} + + for name, info in connection_info["databases"].items(): + with open(info["schema"], "r") as file: + engine = sqlalchemy.create_engine(info["url"]) + schema = yaml.safe_load(file) + databases[name] = DatabaseConnection(schema=schema, engine=engine) + + # Load the Butler (if one is available) + butler: Butler | None = None + if "butler" in connection_info: + repo = connection_info["butler"].pop("repo") + butler = Butler(repo, **connection_info["butler"]) + + def on_message(ws: WebSocketApp, message: str) -> None: + """Message received from the server.""" + response = execute_command(message, databases, butler) + ws.send(response) + + print(f"\033[92mConnecting to rubinTV at {address}:{port}\033[0m") + # Connect to the WebSocket server + ws = WebSocketApp( + f"ws://{address}:{port}/ws/worker", on_message=on_message, on_error=on_error, on_close=on_close + ) + ws.run_forever() + ws.close() diff --git a/python/lsst/rubintv/analysis/service/command.py b/python/lsst/rubintv/analysis/service/command.py new file mode 100644 
index 0000000..c90e3ab --- /dev/null +++ b/python/lsst/rubintv/analysis/service/command.py @@ -0,0 +1,255 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import json +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import sqlalchemy +from lsst.daf.butler import Butler + +logger = logging.getLogger("lsst.rubintv.analysis.service.command") + + +def construct_error_message(error_name: str, description: str) -> str: + """Use a standard format for all error messages. + + Parameters + ---------- + error_name : + Name of the error. + description : + Description of the error. + + Returns + ------- + result : + JSON formatted string. + """ + return json.dumps( + { + "type": "error", + "content": { + "error": error_name, + "description": description, + }, + } + ) + + +def error_msg(error: Exception) -> str: + """Handle errors received while parsing or executing a command. + + Parameters + ---------- + error : + The error that was raised while parsing, executing, + or responding to a command. 
+ + Returns + ------- + response : + The JSON formatted error message sent to the user. + + """ + if isinstance(error, json.decoder.JSONDecodeError): + return construct_error_message("JSON decoder error", error.args[0]) + + if isinstance(error, CommandParsingError): + return construct_error_message("parsing error", error.args[0]) + + if isinstance(error, CommandExecutionError): + return construct_error_message("execution error", error.args[0]) + + if isinstance(error, CommandResponseError): + return construct_error_message("command response error", error.args[0]) + + # We should always receive one of the above errors, so the code should + # never get to here. But we generate this response just in case something + # very unexpected happens, or (more likely) the code is altered in such a + # way that this line is it. + msg = "An unknown error occurred, you should never reach this message." + return construct_error_message(error.__class__.__name__, msg) + + +class CommandParsingError(Exception): + """An `~Exception` caused by an error in parsing a command and + constructing a response. + """ + + pass + + +class CommandExecutionError(Exception): + """An error occurred while executing a command.""" + + pass + + +class CommandResponseError(Exception): + """An error occurred while converting a command result to JSON""" + + pass + + +@dataclass +class DatabaseConnection: + """A connection to a database. + + Attributes + ---------- + engine : + The engine used to connect to the database. + schema : + The schema for the database. + """ + + engine: sqlalchemy.engine.Engine + schema: dict + + +@dataclass(kw_only=True) +class BaseCommand(ABC): + """Base class for commands. + + Attributes + ---------- + result : + The response generated by the command as a `dict` that can + be converted into JSON. + response_type : + The type of response that this command sends to the user. + This should be unique for each command. 
+ """ + + result: dict | None = None + response_type: str + + @abstractmethod + def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butler | None) -> dict: + """Build the contents of the command. + + Parameters + ---------- + databases : + The database connections. + butler : + A conencted Butler. + + Returns + ------- + contents : + The contents of the response to the user. + """ + pass + + def execute(self, databases: dict[str, DatabaseConnection], butler: Butler | None): + """Execute the command. + + This method does not return anything, buts sets the `result`, + the JSON formatted string that is sent to the user. + + Parameters + ---------- + databases : + The database connections. + butler : + A conencted Butler. + + """ + self.result = {"type": self.response_type, "content": self.build_contents(databases, butler)} + + def to_json(self): + """Convert the `result` into JSON.""" + if self.result is None: + raise CommandExecutionError(f"Null result for command {self.__class__.__name__}") + return json.dumps(self.result) + + @classmethod + def register(cls, name: str): + """Register a command.""" + command_registry[name] = cls + + +# Registry of all commands +command_registry = {} + + +def execute_command(command_str: str, databases: dict[str, DatabaseConnection], butler: Butler | None) -> str: + """Parse a JSON formatted string into a command and execute it. + + Command format: + ``` + { + name: command name, + content: command content (usually a dict) + } + ``` + + Parameters + ---------- + command_str : + The JSON formatted command received from the user. + databases : + The database connections. + butler : + A conencted Butler. 
+ """ + try: + command_dict = json.loads(command_str) + if not isinstance(command_dict, dict): + raise CommandParsingError(f"Could not generate a valid command from {command_str}") + except Exception as err: + logging.exception("Error converting command to JSON.") + return error_msg(err) + + try: + if "name" not in command_dict.keys(): + raise CommandParsingError("No command 'name' given") + + if command_dict["name"] not in command_registry.keys(): + raise CommandParsingError(f"Unrecognized command '{command_dict['name']}'") + + if "parameters" in command_dict: + parameters = command_dict["parameters"] + else: + parameters = {} + + command = command_registry[command_dict["name"]](**parameters) + + except Exception as err: + logging.exception("Error parsing command.") + return error_msg(CommandParsingError(f"'{err}' error while parsing command")) + + try: + command.execute(databases, butler) + except Exception as err: + logging.exception("Error executing command.") + return error_msg(CommandExecutionError(f"{err} error executing command.")) + + try: + result = command.to_json() + except Exception as err: + logging.exception("Error converting command response to JSON.") + return error_msg(CommandResponseError(f"{err} error converting command response to JSON.")) + + return result diff --git a/python/lsst/rubintv/analysis/service/database.py b/python/lsst/rubintv/analysis/service/database.py new file mode 100644 index 0000000..6ba8851 --- /dev/null +++ b/python/lsst/rubintv/analysis/service/database.py @@ -0,0 +1,271 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. 
+# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from dataclasses import dataclass +from typing import Sequence + +import sqlalchemy +from lsst.daf.butler import Butler + +from .command import BaseCommand, DatabaseConnection +from .query import Query + + +class UnrecognizedTableError(Exception): + """An error that occurs when a table name does not appear in the schema""" + + pass + + +def get_table_names(schema: dict) -> tuple[str, ...]: + """Given a schema, return a list of dataset names + + Parameters + ---------- + schema : + The schema for a database. + + Returns + ------- + result : + The names of all the tables in the database. + """ + return tuple(tbl["name"] for tbl in schema["tables"]) + + +def get_table_schema(schema: dict, table: str) -> dict: + """Get the schema for a table from the database schema + + Parameters + ---------- + schema: + The schema for a database. + table: + The name of the table in the database. + + Returns + ------- + result: + The schema for the table. + """ + tables = schema["tables"] + for _table in tables: + if _table["name"] == table: + return _table + raise UnrecognizedTableError("Could not find the table '{table}' in database") + + +def column_names_to_models(table: sqlalchemy.Table, columns: list[str]) -> list[sqlalchemy.Column]: + """Return the sqlalchemy model of a Table column for each column name. 
+ + This method is used to generate a sqlalchemy query based on a `~Query`. + + Parameters + ---------- + table : + The name of the table in the database. + columns : + The names of the columns to generate models for. + + Returns + ------- + A list of sqlalchemy columns. + """ + models = [] + for column in columns: + models.append(getattr(table.columns, column)) + return models + + +def query_table( + table: str, + engine: sqlalchemy.engine.Engine, + columns: list[str] | None = None, + query: dict | None = None, +) -> Sequence[sqlalchemy.engine.row.Row]: + """Query a table and return the results + + Parameters + ---------- + engine : + The engine used to connect to the database. + table : + The table that is being queried. + columns : + The columns from the table to return. + If `columns` is ``None`` then all the columns + in the table are returned. + query : + A query used on the table. + If `query` is ``None`` then all the rows + in the query are returned. + + Returns + ------- + result : + A list of the rows that were returned by the query. + """ + metadata = sqlalchemy.MetaData() + _table = sqlalchemy.Table(table, metadata, autoload_with=engine) + + if columns is None: + _query = _table.select() + else: + _query = sqlalchemy.select(*column_names_to_models(_table, columns)) + + if query is not None: + _query = _query.where(Query.from_dict(query)(_table)) + + connection = engine.connect() + result = connection.execute(_query) + return result.fetchall() + + +def calculate_bounds(table: str, column: str, engine: sqlalchemy.engine.Engine) -> tuple[float, float]: + """Calculate the min, max for a column + + Parameters + ---------- + table : + The table that is being queried. + column : + The column to calculate the bounds of. + engine : + The engine used to connect to the database. + + Returns + ------- + result : + The ``(min, max)`` of the chosen column. 
+ """ + metadata = sqlalchemy.MetaData() + _table = sqlalchemy.Table(table, metadata, autoload_with=engine) + _column = _table.columns[column] + + query = sqlalchemy.select((sqlalchemy.func.min(_column))) + connection = engine.connect() + result = connection.execute(query) + col_min = result.fetchone() + if col_min is not None: + col_min = col_min[0] + else: + raise ValueError(f"Could not calculate the min of column {column}") + + query = sqlalchemy.select((sqlalchemy.func.max(_column))) + connection = engine.connect() + result = connection.execute(query) + col_max = result.fetchone() + if col_max is not None: + col_max = col_max[0] + else: + raise ValueError(f"Could not calculate the min of column {column}") + + return col_min, col_max + + +@dataclass(kw_only=True) +class LoadColumnsCommand(BaseCommand): + """Load columns from a database table with an optional query. + + Attributes + ---------- + database : + The name of the database that the table is in. + table : + The table that the columns are loaded from. + columns : + Columns that are to be loaded. If `columns` is ``None`` + then all the columns in the `table` are loaded. + query : + Query used to select rows in the table. + If `query` is ``None`` then all the rows are loaded. 
+ """ + + database: str + table: str + columns: list[str] | None = None + query: dict | None = None + response_type: str = "table columns" + + def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butler | None) -> dict: + # Query the database to return the requested columns + database = databases[self.database] + index_column = get_table_schema(database.schema, self.table)["index_column"] + columns = self.columns + if columns is not None and index_column not in columns: + columns = [index_column] + columns + data = query_table( + table=self.table, + columns=columns, + query=self.query, + engine=database.engine, + ) + + if len(data) == 0: + # There is no column data to return + content: dict = { + "columns": columns, + "data": [], + } + else: + content = { + "columns": [column for column in data[0]._fields], + "data": [list(row) for row in data], + } + + return content + + +@dataclass(kw_only=True) +class CalculateBoundsCommand(BaseCommand): + """Calculate the bounds of a table column. + + Attributes + ---------- + database : + The name of the database that the table is in. + table : + The table that the columns are loaded from. + column : + The column to calculate the bounds of. 
+ """ + + database: str + table: str + column: str + response_type: str = "column bounds" + + def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butler | None) -> dict: + database = databases[self.database] + data = calculate_bounds( + table=self.table, + column=self.column, + engine=database.engine, + ) + return { + "column": self.column, + "bounds": data, + } + + +# Register the commands +LoadColumnsCommand.register("load columns") +CalculateBoundsCommand.register("get bounds") diff --git a/python/lsst/rubintv/analysis/service/query.py b/python/lsst/rubintv/analysis/service/query.py new file mode 100644 index 0000000..49d4fa4 --- /dev/null +++ b/python/lsst/rubintv/analysis/service/query.py @@ -0,0 +1,151 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
class QueryError(Exception):
    """An error that occurred during the construction or parsing of a query."""

    pass


class Query(ABC):
    """Base class for constructing queries."""

    @abstractmethod
    def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList:
        """Run the query on a table.

        Parameters
        ----------
        table :
            The table to run the query on.
        """
        pass

    @staticmethod
    def from_dict(query_dict: dict[str, Any]) -> Query:
        """Construct a query from a dictionary of parameters.

        Parameters
        ----------
        query_dict :
            Kwargs used to initialize the query.
            There should only be two keys in this dict,
            the ``name`` of the query and the ``content`` used
            to initialize the query.

        Raises
        ------
        QueryError
            If the dictionary is malformed (e.g. missing keys or bad
            ``content``) or the query ``name`` is not recognized.
        """
        try:
            if query_dict["name"] == "EqualityQuery":
                return EqualityQuery.from_dict(query_dict["content"])
            elif query_dict["name"] == "ParentQuery":
                return ParentQuery.from_dict(query_dict["content"])
        except Exception as err:
            # Bug fix: chain the original exception so the root cause
            # (missing key, bad child query, ...) survives in the traceback
            # instead of being silently discarded.
            raise QueryError("Failed to parse query.") from err

        raise QueryError("Unrecognized query type")
+ """ + + def __init__( + self, + column: str, + operator: str, + value: Any, + ): + self.operator = operator + self.column = column + self.value = value + + def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList: + column = table.columns[self.column] + + if self.operator in ("eq", "ne", "lt", "le", "gt", "ge"): + operator = getattr(op, self.operator) + return operator(column, self.value) + + if self.operator not in ("startswith", "endswith", "contains"): + raise QueryError(f"Unrecognized Equality operator {self.operator}") + + return getattr(column, self.operator)(self.value) + + @staticmethod + def from_dict(query_dict: dict[str, Any]) -> EqualityQuery: + return EqualityQuery(**query_dict) + + +class ParentQuery(Query): + """A query that uses a binary operation to combine other queries. + + Parameters + ---------- + children : + The child queries that are combined using the binary operator. + operator : + The operator that us used to combine the queries. 
+ """ + + def __init__(self, children: list[Query], operator: str): + self.children = children + self.operator = operator + + def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList: + child_results = [child(table) for child in self.children] + try: + if self.operator == "AND": + return sqlalchemy.and_(*child_results) + if self.operator == "OR": + return sqlalchemy.or_(*child_results) + if self.operator == "NOT": + return sqlalchemy.not_(*child_results) + if self.operator == "XOR": + return sqlalchemy.and_( + sqlalchemy.or_(*child_results), + sqlalchemy.not_(sqlalchemy.and_(*child_results)), + ) + except Exception: + raise QueryError("Error applying a boolean query statement.") + + @staticmethod + def from_dict(query_dict: dict[str, Any]) -> ParentQuery: + return ParentQuery( + children=[Query.from_dict(child) for child in query_dict["children"]], + operator=query_dict["operator"], + ) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f68ad48 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +numpy>=1.25.2 +scipy +matplotlib +pydantic +pyyaml +sqlalchemy +astropy +websocket-client +lsst-daf-butler + +# the following import is temporary while testing +tornado diff --git a/scripts/config.yaml b/scripts/config.yaml new file mode 100644 index 0000000..4e690ac --- /dev/null +++ b/scripts/config.yaml @@ -0,0 +1,8 @@ +--- +databases: + summitcdb: + schema: "/Users/fred3m/temp/visitDb/summit.yaml" + url: "sqlite:////Users/fred3m/temp/visitDb/summit.db" +#butler: +# repo: /repos/main +# skymap: hsc_rings_v1 diff --git a/scripts/mock_server.py b/scripts/mock_server.py new file mode 100644 index 0000000..814103e --- /dev/null +++ b/scripts/mock_server.py @@ -0,0 +1,259 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). 
# Default port and address to listen on
LISTEN_PORT = 2000
LISTEN_ADDRESS = "localhost"


# ANSI color codes for printing to the terminal
ansi_colors = {
    "black": "30",
    "red": "31",
    "green": "32",
    "yellow": "33",
    "blue": "34",
    "magenta": "35",
    "cyan": "36",
    "white": "37",
}


def log(message, color, end="\033[31"):
    """Print a message to the terminal in color.

    Parameters
    ----------
    message :
        The message to print.
    color :
        The color to print the message in.
    end :
        The color future messages should be printed in.
    """
    code = ansi_colors[color]
    # Emit the color escape, the message, then the escape that sets the
    # color for subsequent terminal output (``end`` + trailing ``m``).
    colored_message = "\033[" + code + "m" + str(message)
    print(colored_message + end + "m")


class WorkerPodStatus(Enum):
    """Status of a worker pod."""

    IDLE = "idle"
    BUSY = "busy"
    def open(self, type: str) -> None:
        """
        Client opens a websocket

        Parameters
        ----------
        type :
            The type of client that is connecting.
        """
        # Each connection gets a unique id so replies can be routed back
        # to the correct websocket.
        self.client_id = str(uuid.uuid4())
        if type == "worker":
            # Workers are tracked in the module-level `workers` registry.
            workers[self.client_id] = WorkerPod(self.client_id, self)
            log(f"New worker {self.client_id} connected. Total workers: {len(workers)}", "blue")
        if type == "client":
            # Clients are tracked in the module-level `clients` registry.
            clients[self.client_id] = self
            log(f"New client {self.client_id} connected. Total clients: {len(clients)}", "yellow")

    def on_message(self, message: str) -> None:
        """
        Message received from a client or worker.

        Parameters
        ----------
        message :
            The message received from the client or worker.
        """
        if self.client_id in clients:
            # Message from a client: hand it to an idle worker, or queue
            # it if every worker is busy.
            log(f"Message received from {self.client_id}", "yellow")
            client = clients[self.client_id]

            # Find an idle worker
            idle_worker = None
            for worker in workers.values():
                if worker.status == WorkerPodStatus.IDLE:
                    idle_worker = worker
                    break

            if idle_worker is None:
                # No idle worker found, add to queue
                queue.append(QueueItem(message, client))
                return
            idle_worker.process(message, client)
            return

        if self.client_id in workers:
            # Message from a worker: treat it as the finished result of
            # the job it was processing and relay it to the client.
            worker = workers[self.client_id]
            worker.on_finished(message)
            log(
                f"Message received from worker {self.client_id}. New status {worker.status}",
                "blue",
            )

            # Check the queue for any outstanding jobs.
            if len(queue) > 0:
                queue_item = queue.pop(0)
                worker.process(queue_item.message, queue_item.client)
            return
class WorkerPod:
    """State of a worker pod.

    Attributes
    ----------
    id :
        The id of the worker pod.
    ws :
        The websocket connection to the worker pod.
    status :
        The status of the worker pod.
    connected_client :
        The client that is connected to this worker pod.
    """

    status: WorkerPodStatus
    connected_client: WebSocketHandler | None

    def __init__(self, id: str, ws: WebSocketHandler):
        self.id = id
        self.ws = ws
        # A new worker starts out idle with no client attached.
        self.status = WorkerPodStatus.IDLE
        self.connected_client = None

    def process(self, message: str, connected_client: WebSocketHandler):
        """Process a message from a client.

        Parameters
        ----------
        message :
            The message to process.
        connected_client :
            The client that is connected to this worker pod.
        """
        # Mark this worker busy and remember the requester so the reply
        # in `on_finished` can be routed back to the right client.
        self.status = WorkerPodStatus.BUSY
        self.connected_client = connected_client
        log(f"Worker {self.id} processing message from client {connected_client.client_id}", "blue")
        # Send the job to the worker pod
        self.ws.write_message(message)

    def on_finished(self, message):
        """Called when the worker pod has finished processing a message.

        ``message`` is the worker's reply, or the sentinel string
        "Client disconnected" when the requesting client went away.
        """
        if (
            self.connected_client is not None
            and self.connected_client.ws_connection is not None
            and message != "Client disconnected"
        ):
            # Send the reply to the client that made the request.
            self.connected_client.write_message(message)
        else:
            log(f"Worker {self.id} finished processing, but no client was connected.", "red")
        # Either way the worker becomes available for new jobs.
        self.status = WorkerPodStatus.IDLE
        self.connected_client = None
@dataclass
class QueueItem:
    """An item in the client queue.

    Attributes
    ----------
    message :
        The message to process.
    client :
        The client that is making a request.
    """

    message: str
    client: WebSocketHandler


workers: dict[str, WorkerPod] = {}  # Keep track of connected worker pods
clients: dict[str, WebSocketHandler] = {}  # Keep track of connected clients
queue: list[QueueItem] = []  # Queue of messages to be processed


def main():
    """Build the tornado application and run the event loop forever."""
    # Create tornado application and supply URL routes
    app = tornado.web.Application(WebSocketHandler.urls())  # type: ignore

    # Set up the HTTP server on the configured address and port
    http_server = tornado.httpserver.HTTPServer(app)
    http_server.listen(LISTEN_PORT, LISTEN_ADDRESS)

    log(f"Listening on address: {LISTEN_ADDRESS}, {LISTEN_PORT}", "green")

    # Start IO/Event loop
    tornado.ioloop.IOLoop.instance().start()


if __name__ == "__main__":
    main()
# Default to the config.yaml that ships next to this script.
default_config = os.path.join(pathlib.Path(__file__).parent.absolute(), "config.yaml")


def main():
    """Parse command-line options, load the config file, and start a
    RubinTV worker that connects to the web app via websockets."""
    parser = argparse.ArgumentParser(description="Initialize a new RubinTV worker.")
    parser.add_argument(
        "-a", "--address", default="localhost", type=str, help="Address of the rubinTV web app."
    )
    parser.add_argument(
        "-p", "--port", default=2000, type=int, help="Port of the rubinTV web app websockets."
    )
    parser.add_argument(
        "-c", "--config", default=default_config, type=str, help="Location of the configuration file."
    )
    opts = parser.parse_args()

    # Load the configuration file
    with open(opts.config, "r") as config_file:
        config = yaml.safe_load(config_file)

    # Run the client and connect to rubinTV via websockets
    run_worker(opts.address, opts.port, config)


if __name__ == "__main__":
    main()
+ - name: physical_filter + datatype: char + description: ID of physical filter, + the filter associated with a particular instrument. + - name: obsNight + datatype: date + description: The night of the observation. This is different than the + observation date, as this is the night that the observations started, + so for observations after midnight obsStart and obsNight will be + different days. + - name: obsStart + datatype: datetime + description: Start time of the exposure at the fiducial center + of the focal plane array, TAI, accurate to 10ms. + - name: obsStartMJD + datatype: double + description: Start of the exposure in MJD, TAI, accurate to 10ms. diff --git a/tests/test_command.py b/tests/test_command.py new file mode 100644 index 0000000..bab82c7 --- /dev/null +++ b/tests/test_command.py @@ -0,0 +1,210 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
class TestCommand(utils.RasTestCase):
    # Base class for command tests: builds a temporary sqlite database
    # from tests/schema.yaml and provides `execute_command` to round-trip
    # a command through the JSON command API.
    def setUp(self):
        path = os.path.dirname(__file__)
        yaml_filename = os.path.join(path, "schema.yaml")

        with open(yaml_filename) as file:
            schema = yaml.safe_load(file)
        # Temporary on-disk sqlite file populated with the test data.
        db_file = tempfile.NamedTemporaryFile(delete=False)
        utils.create_database(schema, db_file.name)
        self.db_file = db_file
        self.db_filename = db_file.name
        self.schema = schema

        # Load the database connection information
        self.databases = {
            "testdb": lras.command.DatabaseConnection(
                schema=schema, engine=sqlalchemy.create_engine("sqlite:///" + db_file.name)
            )
        }

        # Set up the sqlalchemy connection
        self.engine = sqlalchemy.create_engine("sqlite:///" + db_file.name)

    def tearDown(self) -> None:
        # Remove the temporary database created in setUp.
        self.db_file.close()
        os.remove(self.db_file.name)

    def execute_command(self, command: dict, response_type: str) -> dict:
        """Serialize ``command`` to JSON, execute it, assert the response
        ``type`` matches ``response_type``, and return the response content.
        """
        command_json = json.dumps(command)
        response = lras.command.execute_command(command_json, self.databases, None)
        result = json.loads(response)
        self.assertEqual(result["type"], response_type)
        return result["content"]


class TestCalculateBoundsCommand(TestCommand):
    def test_calculate_bounds_command(self):
        # The "get bounds" command should report the min/max of `dec`.
        command = {
            "name": "get bounds",
            "parameters": {
                "database": "testdb",
                "table": "ExposureInfo",
                "column": "dec",
            },
        }
        content = self.execute_command(command, "column bounds")
        self.assertEqual(content["column"], "dec")
        self.assertListEqual(content["bounds"], [-40, 50])
class TestCommandErrors(TestCommand):
    def check_error_response(self, content: dict, error: str, description: str | None = None):
        """Assert an error response has the expected ``error`` category
        and, when given, the expected ``description``.
        """
        self.assertEqual(content["error"], error)
        if description is not None:
            self.assertEqual(content["description"], description)

    def test_errors(self):
        # Command cannot be decoded as JSON dict
        content = self.execute_command("{'test': [1,2,3,0004,}", "error")
        self.check_error_response(content, "parsing error")

        # Command does not contain a "name"
        command = {"content": {}}
        content = self.execute_command(command, "error")
        self.check_error_response(
            content,
            "parsing error",
            "'No command 'name' given' error while parsing command",
        )

        # Command has an invalid name
        command = {"name": "invalid name"}
        content = self.execute_command(command, "error")
        self.check_error_response(
            content,
            "parsing error",
            "'Unrecognized command 'invalid name'' error while parsing command",
        )

        # Command has no parameters
        command = {"name": "get bounds"}
        content = self.execute_command(command, "error")
        self.check_error_response(
            content,
            "parsing error",
        )

        # Command has invalid parameters
        command = {
            "name": "get bounds",
            "parameters": {
                "a": 1,
            },
        }
        content = self.execute_command(command, "error")
        self.check_error_response(
            content,
            "parsing error",
        )

        # Command execution failed (table name does not exist)
        command = {
            "name": "get bounds",
            "parameters": {"database": "testdb", "table": "InvalidTable", "column": "invalid_column"},
        }
        content = self.execute_command(command, "error")
        self.check_error_response(
            content,
            "execution error",
        )
class TestDatabase(TestCase):
    """Tests for the low-level database helper functions in
    `lsst.rubintv.analysis.service.database`.
    """

    def setUp(self):
        path = os.path.dirname(__file__)
        yaml_filename = os.path.join(path, "schema.yaml")

        with open(yaml_filename) as file:
            schema = yaml.safe_load(file)
        # Temporary on-disk sqlite file populated with the test data.
        db_file = tempfile.NamedTemporaryFile(delete=False)
        utils.create_database(schema, db_file.name)
        self.db_file = db_file
        self.db_filename = db_file.name
        self.schema = schema

        # Set up the sqlalchemy connection
        self.engine = sqlalchemy.create_engine("sqlite:///" + db_file.name)

    def tearDown(self) -> None:
        # Remove the temporary database created in setUp.
        self.db_file.close()
        os.remove(self.db_file.name)

    def test_get_table_names(self):
        table_names = lras.database.get_table_names(self.schema)
        self.assertTupleEqual(table_names, ("ExposureInfo",))

    def test_get_table_schema(self):
        schema = lras.database.get_table_schema(self.schema, "ExposureInfo")
        self.assertEqual(schema["name"], "ExposureInfo")

        # Column order must match the schema file.
        columns = [
            "exposure_id",
            "seq_num",
            "ra",
            "dec",
            "expTime",
            "physical_filter",
            "obsNight",
            "obsStart",
            "obsStartMJD",
        ]
        for n, column in enumerate(schema["columns"]):
            self.assertEqual(column["name"], columns[n])

    def test_query_full_table(self):
        truth_table = utils.get_test_data()
        truth = utils.ap_table_to_list(truth_table)

        # Bug fix: removed a leftover debug ``print(data)`` that spammed
        # the test output.
        data = lras.database.query_table("ExposureInfo", engine=self.engine)

        self.assertListEqual(list(data[0]._fields), list(truth_table.columns))

        for n in range(len(truth)):
            true_row = tuple(truth[n])
            row = tuple(data[n])
            self.assertTupleEqual(row, true_row)

    def test_query_columns(self):
        truth = utils.get_test_data()
        truth = utils.ap_table_to_list(truth["ra", "dec"])

        data = lras.database.query_table("ExposureInfo", columns=["ra", "dec"], engine=self.engine)

        self.assertListEqual(list(data[0]._fields), ["ra", "dec"])

        for n in range(len(truth)):
            true_row = tuple(truth[n])
            row = tuple(data[n])
            self.assertTupleEqual(row, true_row)

    def test_calculate_bounds(self):
        result = lras.database.calculate_bounds("ExposureInfo", "dec", self.engine)
        self.assertTupleEqual(result, (-40, 50))
class TestQuery(utils.RasTestCase):
    # Tests for building sqlalchemy clauses from Query objects and
    # query dictionaries, against a temporary sqlite database.
    def setUp(self):
        path = os.path.dirname(__file__)
        yaml_filename = os.path.join(path, "schema.yaml")

        with open(yaml_filename) as file:
            schema = yaml.safe_load(file)
        # Temporary on-disk sqlite file populated with the test data.
        db_file = tempfile.NamedTemporaryFile(delete=False)
        utils.create_database(schema, db_file.name)
        self.db_file = db_file
        self.db_filename = db_file.name
        self.schema = schema

        # Set up the sqlalchemy connection and reflect the test table.
        self.engine = sqlalchemy.create_engine("sqlite:///" + db_file.name)
        self.metadata = sqlalchemy.MetaData()
        self.table = sqlalchemy.Table("ExposureInfo", self.metadata, autoload_with=self.engine)

    def tearDown(self) -> None:
        # Remove the temporary database created in setUp.
        self.db_file.close()
        os.remove(self.db_file.name)

    def test_equality(self):
        table = self.table
        column = table.columns.dec

        # Each supported operator should produce the equivalent
        # sqlalchemy comparison clause.
        value = 0
        truth_dict = {
            "eq": column == value,
            "ne": column != value,
            "lt": column < value,
            "le": column <= value,
            "gt": column > value,
            "ge": column >= value,
        }

        for operator, truth in truth_dict.items():
            self.assertTrue(lras.query.EqualityQuery("dec", operator, value)(table).compare(truth))

    def test_query(self):
        table = self.table

        # dec > 0
        query = lras.query.EqualityQuery("dec", "gt", 0)
        result = query(table)
        self.assertTrue(result.compare(table.columns.dec > 0))

        # dec < 0 and ra > 60
        query = lras.query.ParentQuery(
            operator="AND",
            children=[
                lras.query.EqualityQuery("dec", "lt", 0),
                lras.query.EqualityQuery("ra", "gt", 60),
            ],
        )
        result = query(table)
        truth = sqlalchemy.and_(
            table.columns.dec < 0,
            table.columns.ra > 60,
        )
        self.assertTrue(result.compare(truth))

        # Check queries that are unequal to verify that they don't work
        result = query(table)
        truth = sqlalchemy.and_(
            table.columns.dec < 0,
            table.columns.ra > 70,
        )
        self.assertFalse(result.compare(truth))
    def test_database_query(self):
        data = utils.get_test_data()

        # dec > 0 (and is not None)
        query1 = {
            "name": "EqualityQuery",
            "content": {
                "column": "dec",
                "operator": "gt",
                "value": 0,
            },
        }
        # ra > 60 (and is not None)
        query2 = {
            "name": "EqualityQuery",
            "content": {
                "column": "ra",
                "operator": "gt",
                "value": 60,
            },
        }

        # Test 1: dec > 0 (and is not None)
        query = query1
        result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
        truth = data[[False, False, False, False, False, True, False, True, True, True]]
        truth = utils.ap_table_to_list(truth)
        self.assertDataTableEqual(result, truth)

        # Test 2: dec > 0 and ra > 60 (and neither is None)
        query = {
            "name": "ParentQuery",
            "content": {
                "operator": "AND",
                "children": [query1, query2],
            },
        }
        result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
        truth = data[[False, False, False, False, False, False, False, False, True, True]]
        truth = utils.ap_table_to_list(truth)
        self.assertDataTableEqual(result, truth)

        # Test 3: dec <= 0 or ra > 60 (and neither is None)
        query = {
            "name": "ParentQuery",
            "content": {
                "operator": "OR",
                "children": [
                    {
                        "name": "ParentQuery",
                        "content": {
                            "operator": "NOT",
                            "children": [query1],
                        },
                    },
                    query2,
                ],
            },
        }

        result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
        truth = data[[True, True, False, True, True, False, True, False, True, True]]
        truth = utils.ap_table_to_list(truth)
        self.assertDataTableEqual(result, truth)

        # Test 4: dec > 0 XOR ra > 60
        query = {
            "name": "ParentQuery",
            "content": {
                "operator": "XOR",
                "children": [query1, query2],
            },
        }
        result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
        truth = data[[False, False, False, False, False, True, False, False, False, False]]
        truth = utils.ap_table_to_list(truth)
        self.assertDataTableEqual(result, truth)

    def test_database_string_query(self):
        data = utils.get_test_data()

        # Test equality
        query = {
            "name": "EqualityQuery",
            "content": {
                "column": "physical_filter",
                "operator": "eq",
                "value": "DECam r-band",
            },
        }
        result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
        truth = data[[False, False, False, False, False, False, True, False, False, False]]
        truth = utils.ap_table_to_list(truth)
        self.assertDataTableEqual(result, truth)

        # Test "startswith"
        query = {
            "name": "EqualityQuery",
            "content": {
                "column": "physical_filter",
                "operator": "startswith",
                "value": "DECam",
            },
        }
        result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
        truth = data[[False, False, False, False, False, True, True, True, True, True]]
        truth = utils.ap_table_to_list(truth)
        self.assertDataTableEqual(result, truth)

        # Test "endswith"
        query = {
            "name": "EqualityQuery",
            "content": {
                "column": "physical_filter",
                "operator": "endswith",
                "value": "r-band",
            },
        }
        result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
        truth = data[[False, True, False, False, False, False, True, False, False, False]]
        truth = utils.ap_table_to_list(truth)
        self.assertDataTableEqual(result, truth)

        # Test "like"
        query = {
            "name": "EqualityQuery",
            "content": {
                "column": "physical_filter",
                "operator": "contains",
                "value": "T r",
            },
        }
        result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
        truth = data[[False, True, False, False, False, False, False, False, False, False]]
        truth = utils.ap_table_to_list(truth)
        self.assertDataTableEqual(result, truth)

    # NOTE(review): "datatime" looks like a typo for "datetime" in this
    # method name; renaming is out of scope for a documentation-only pass.
    def test_database_datatime_query(self):
        data = utils.get_test_data()

        # Test <
        query1 = {
            "name": "EqualityQuery",
            "content": {
                "column": "obsStart",
                "operator": "lt",
                "value": "2023-05-19 23:23:23",
            },
        }
        result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query1)
        truth = data[[True, True, True, False, False, True, True, True, True, True]]
        truth = utils.ap_table_to_list(truth)
        self.assertDataTableEqual(result, truth)

        # Test >
        query2 = {
            "name": "EqualityQuery",
            "content": {
                "column": "obsStart",
                "operator": "gt",
                "value": "2023-05-01 23:23:23",
            },
        }
        result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query2)
        truth = data[[True, True, True, True, True, False, False, False, False, False]]
        truth = utils.ap_table_to_list(truth)
        self.assertDataTableEqual(result, truth)

        # Test in range
        query3 = {
            "name": "ParentQuery",
            "content": {
                "operator": "AND",
                "children": [query1, query2],
            },
        }
        result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query3)
        truth = data[[True, True, True, False, False, False, False, False, False, False]]
        truth = utils.ap_table_to_list(truth)
        self.assertDataTableEqual(result, truth)
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import sqlite3 +from unittest import TestCase + +from astropy.table import Table as ApTable +from astropy.time import Time + +# Convert visit DB datatypes to sqlite3 datatypes +datatype_transform = { + "int": "integer", + "long": "integer", + "double": "real", + "float": "real", + "char": "text", + "date": "text", + "datetime": "text", +} + + +def create_table(cursor: sqlite3.Cursor, tbl_name: str, schema: dict): + """Create a table in an sqlite database. + + Parameters + ---------- + cursor : + The cursor associated with the database connection. + tbl_name : + The name of the table to create. + schema : + The schema of the table. + """ + command = f"CREATE TABLE {tbl_name}(\n" + for field in schema: + command += f' {field["name"]} {datatype_transform[field["datatype"]]},\n' + command = command[:-2] + "\n);" + cursor.execute(command) + + +def get_test_data_dict() -> dict: + """Get a dictionary containing the test data""" + obs_start = [ + "2023-05-19 20:20:20", + "2023-05-19 21:21:21", + "2023-05-19 22:22:22", + "2023-05-19 23:23:23", + "2023-05-20 00:00:00", + "2023-02-14 22:22:22", + "2023-02-14 23:23:23", + "2023-02-14 00:00:00", + "2023-02-14 01:01:01", + "2023-02-14 02:02:02", + ] + + obs_start_mjd = [Time(time).mjd for time in obs_start] + + return { + "exposure_id": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18], + "seq_num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "ra": [10, 20, None, 40, 50, 60, 70, None, 90, 100], + "dec": [-40, -30, None, -10, 0, 10, None, 30, 40, 50], + "expTime": [30, 30, 10, 15, 15, 30, 30, 30, 15, 20], + "physical_filter": [ + "LSST g-band", + "LSST r-band", + "LSST i-band", + "LSST z-band", + "LSST y-band", + "DECam g-band", + "DECam r-band", + "DECam i-band", + "DECam z-band", + "DECam y-band", + ], + "obsNight": [ + "2023-05-19", + "2023-05-19", + "2023-05-19", + "2023-05-19", 
+ "2023-05-19", + "2023-02-14", + "2023-02-14", + "2023-02-14", + "2023-02-14", + "2023-02-14", + ], + "obsStart": obs_start, + "obsStartMJD": obs_start_mjd, + } + + +def get_test_data() -> ApTable: + """Generate data for the test database""" + data_dict = get_test_data_dict() + + table = ApTable(list(data_dict.values()), names=list(data_dict.keys())) + return table + + +def ap_table_to_list(data: ApTable) -> list: + """Convert an astropy Table into a list of tuples.""" + rows = [] + for row in data: + rows.append(tuple(row)) + return rows + + +def create_database(schema: dict, db_filename: str): + """Create the test database""" + tbl_name = "ExposureInfo" + connection = sqlite3.connect(db_filename) + cursor = connection.cursor() + + create_table(cursor, tbl_name, schema["tables"][0]["columns"]) + + data = get_test_data_dict() + + for n in range(len(data["exposure_id"])): + row = tuple(data[key][n] for key in data.keys()) + value_str = "?, " * (len(row) - 1) + "?" + command = f"INSERT INTO {tbl_name} VALUES({value_str});" + cursor.execute(command, row) + connection.commit() + cursor.close() + + +class TableMismatchError(AssertionError): + pass + + +class RasTestCase(TestCase): + """Base class for tests in this package + + For now this only includes methods to check the + database results, but in the future other checks + might be put in place. + """ + + @staticmethod + def get_data_table_indices(table: list[tuple]) -> list[int]: + """Get the index for each row in the data table. + + Parameters + ---------- + table : + The table containing the data. + + Returns + ------- + result : + The index for each row in the table. + """ + # Return the seq_num as an index + return [row[1] for row in table] + + def assertDataTableEqual(self, result, truth): + """Check if two data tables are equal. + + Parameters + ---------- + result : + The result generated by the test that is checked. + truth : + The expected value of the test. 
+ """ + if len(result) != len(truth): + msg = "Data tables have a different number of rows: " + msg += f"indices: [{self.get_data_table_indices(result)}], [{self.get_data_table_indices(truth)}]" + raise TableMismatchError(msg) + try: + for n in range(len(truth)): + true_row = tuple(truth[n]) + row = tuple(result[n]) + self.assertTupleEqual(row, true_row) + except AssertionError: + msg = "Mismatched tables: " + msg += f"indices: [{self.get_data_table_indices(result)}], [{self.get_data_table_indices(truth)}]" + raise TableMismatchError(msg) From 7c4409e3580cc27796ab6b528645f41533971bdf Mon Sep 17 00:00:00 2001 From: Fred Moolekamp Date: Mon, 23 Oct 2023 16:36:04 -0400 Subject: [PATCH 2/3] Responses to reviewer comments (rebase before merging) --- .../lsst/rubintv/analysis/service/__init__.py | 2 +- .../lsst/rubintv/analysis/service/client.py | 90 +++++++------ .../lsst/rubintv/analysis/service/command.py | 21 +-- .../lsst/rubintv/analysis/service/database.py | 2 +- python/lsst/rubintv/analysis/service/query.py | 29 ++-- python/lsst/rubintv/analysis/service/utils.py | 40 ++++++ scripts/mock_server.py | 125 +++++++++--------- scripts/rubintv_worker.py | 5 +- 8 files changed, 177 insertions(+), 137 deletions(-) create mode 100644 python/lsst/rubintv/analysis/service/utils.py diff --git a/python/lsst/rubintv/analysis/service/__init__.py b/python/lsst/rubintv/analysis/service/__init__.py index e144f74..ae91fbd 100644 --- a/python/lsst/rubintv/analysis/service/__init__.py +++ b/python/lsst/rubintv/analysis/service/__init__.py @@ -1 +1 @@ -from . import command, database, query +from . 
import command, database, query, utils diff --git a/python/lsst/rubintv/analysis/service/client.py b/python/lsst/rubintv/analysis/service/client.py index 9504b0b..38fba82 100644 --- a/python/lsst/rubintv/analysis/service/client.py +++ b/python/lsst/rubintv/analysis/service/client.py @@ -26,56 +26,64 @@ from websocket import WebSocketApp from .command import DatabaseConnection, execute_command +from .utils import printc, Colors logger = logging.getLogger("lsst.rubintv.analysis.service.client") -def on_error(ws: WebSocketApp, error: str) -> None: - """Error received from the server.""" - print(f"\033[91mError: {error}\033[0m") +class Worker: + def __init__(self, address: str, port: int, connection_info: dict[str, dict]): + self._address = address + self._port = port + self._connection_info = connection_info + def on_error(self, ws: WebSocketApp, error: str) -> None: + """Error received from the server.""" + printc(f"Error: {error}", color=Colors.BRIGHT_RED) -def on_close(ws: WebSocketApp, close_status_code: str, close_msg: str) -> None: - """Connection closed by the server.""" - print("\033[93mConnection closed\033[0m") + def on_close(self, ws: WebSocketApp, close_status_code: str, close_msg: str) -> None: + """Connection closed by the server.""" + printc("Connection closed", Colors.BRIGHT_YELLOW) + def run(self) -> None: + """Run the worker and connect to the rubinTV server. -def run_worker(address: str, port: int, connection_info: dict[str, dict]) -> None: - """Run the worker and connect to the rubinTV server. + Parameters + ---------- + address : + Address of the rubinTV web app. + port : + Port of the rubinTV web app websockets. + connection_info : + Connections . + """ + # Load the database connection information + databases: dict[str, DatabaseConnection] = {} - Parameters - ---------- - address : - Address of the rubinTV web app. - port : - Port of the rubinTV web app websockets. - connection_info : - Connections . 
- """ - # Load the database connection information - databases: dict[str, DatabaseConnection] = {} + for name, info in self._connection_info["databases"].items(): + with open(info["schema"], "r") as file: + engine = sqlalchemy.create_engine(info["url"]) + schema = yaml.safe_load(file) + databases[name] = DatabaseConnection(schema=schema, engine=engine) - for name, info in connection_info["databases"].items(): - with open(info["schema"], "r") as file: - engine = sqlalchemy.create_engine(info["url"]) - schema = yaml.safe_load(file) - databases[name] = DatabaseConnection(schema=schema, engine=engine) + # Load the Butler (if one is available) + butler: Butler | None = None + if "butler" in self._connection_info: + repo = self._connection_info["butler"].pop("repo") + butler = Butler(repo, **self._connection_info["butler"]) - # Load the Butler (if one is available) - butler: Butler | None = None - if "butler" in connection_info: - repo = connection_info["butler"].pop("repo") - butler = Butler(repo, **connection_info["butler"]) + def on_message(ws: WebSocketApp, message: str) -> None: + """Message received from the server.""" + response = execute_command(message, databases, butler) + ws.send(response) - def on_message(ws: WebSocketApp, message: str) -> None: - """Message received from the server.""" - response = execute_command(message, databases, butler) - ws.send(response) - - print(f"\033[92mConnecting to rubinTV at {address}:{port}\033[0m") - # Connect to the WebSocket server - ws = WebSocketApp( - f"ws://{address}:{port}/ws/worker", on_message=on_message, on_error=on_error, on_close=on_close - ) - ws.run_forever() - ws.close() + printc(f"Connecting to rubinTV at {self._address}:{self._port}", Colors.BRIGHT_GREEN) + # Connect to the WebSocket server + ws = WebSocketApp( + f"ws://{self._address}:{self._port}/ws/worker", + on_message=on_message, + on_error=self.on_error, + on_close=self.on_close, + ) + ws.run_forever() + ws.close() diff --git 
a/python/lsst/rubintv/analysis/service/command.py b/python/lsst/rubintv/analysis/service/command.py index c90e3ab..776ea2b 100644 --- a/python/lsst/rubintv/analysis/service/command.py +++ b/python/lsst/rubintv/analysis/service/command.py @@ -141,6 +141,7 @@ class BaseCommand(ABC): This should be unique for each command. """ + command_registry = {} result: dict | None = None response_type: str @@ -153,7 +154,7 @@ def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butle databases : The database connections. butler : - A conencted Butler. + A connected Butler. Returns ------- @@ -187,11 +188,7 @@ def to_json(self): @classmethod def register(cls, name: str): """Register a command.""" - command_registry[name] = cls - - -# Registry of all commands -command_registry = {} + BaseCommand.command_registry[name] = cls def execute_command(command_str: str, databases: dict[str, DatabaseConnection], butler: Butler | None) -> str: @@ -212,7 +209,7 @@ def execute_command(command_str: str, databases: dict[str, DatabaseConnection], databases : The database connections. butler : - A conencted Butler. + A connected Butler. 
""" try: command_dict = json.loads(command_str) @@ -226,15 +223,11 @@ def execute_command(command_str: str, databases: dict[str, DatabaseConnection], if "name" not in command_dict.keys(): raise CommandParsingError("No command 'name' given") - if command_dict["name"] not in command_registry.keys(): + if command_dict["name"] not in BaseCommand.command_registry.keys(): raise CommandParsingError(f"Unrecognized command '{command_dict['name']}'") - if "parameters" in command_dict: - parameters = command_dict["parameters"] - else: - parameters = {} - - command = command_registry[command_dict["name"]](**parameters) + parameters = command_dict.get("parameters", {}) + command = BaseCommand.command_registry[command_dict["name"]](**parameters) except Exception as err: logging.exception("Error parsing command.") diff --git a/python/lsst/rubintv/analysis/service/database.py b/python/lsst/rubintv/analysis/service/database.py index 6ba8851..eb06f48 100644 --- a/python/lsst/rubintv/analysis/service/database.py +++ b/python/lsst/rubintv/analysis/service/database.py @@ -219,7 +219,7 @@ def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butle engine=database.engine, ) - if len(data) == 0: + if not data: # There is no column data to return content: dict = { "columns": columns, diff --git a/python/lsst/rubintv/analysis/service/query.py b/python/lsst/rubintv/analysis/service/query.py index 49d4fa4..4f6c0bf 100644 --- a/python/lsst/rubintv/analysis/service/query.py +++ b/python/lsst/rubintv/analysis/service/query.py @@ -123,23 +123,24 @@ class ParentQuery(Query): """ def __init__(self, children: list[Query], operator: str): - self.children = children - self.operator = operator + self._children = children + self._operator = operator def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList: - child_results = [child(table) for child in self.children] + child_results = [child(table) for child in self._children] try: - if self.operator == 
"AND": - return sqlalchemy.and_(*child_results) - if self.operator == "OR": - return sqlalchemy.or_(*child_results) - if self.operator == "NOT": - return sqlalchemy.not_(*child_results) - if self.operator == "XOR": - return sqlalchemy.and_( - sqlalchemy.or_(*child_results), - sqlalchemy.not_(sqlalchemy.and_(*child_results)), - ) + match self._operator: + case "AND": + return sqlalchemy.and_(*child_results) + case "OR": + return sqlalchemy.or_(*child_results) + case "NOT": + return sqlalchemy.not_(*child_results) + case "XOR": + return sqlalchemy.and_( + sqlalchemy.or_(*child_results), + sqlalchemy.not_(sqlalchemy.and_(*child_results)), + ) except Exception: raise QueryError("Error applying a boolean query statement.") diff --git a/python/lsst/rubintv/analysis/service/utils.py b/python/lsst/rubintv/analysis/service/utils.py new file mode 100644 index 0000000..806215b --- /dev/null +++ b/python/lsst/rubintv/analysis/service/utils.py @@ -0,0 +1,40 @@ +from enum import Enum + + +# ANSI color codes for printing to the terminal +class Colors(Enum): + RESET = 0 + BLACK = 30 + RED = 31 + GREEN = 32 + YELLOW = 33 + BLUE = 34 + MAGENTA = 35 + CYAN = 36 + WHITE = 37 + DEFAULT = 39 + BRIGHT_BLACK = 90 + BRIGHT_RED = 91 + BRIGHT_GREEN = 92 + BRIGHT_YELLOW = 93 + BRIGHT_BLUE = 94 + BRIGHT_MAGENTA = 95 + BRIGHT_CYAN = 96 + BRIGHT_WHITE = 97 + + +def printc(message: str, color: Colors, end_color: Colors = Colors.RESET): + """Print a message to the terminal in color. + + After printing reset the color by default. + + Parameters + ---------- + message : + The message to print. + color : + The color to print the message in. + end_color : + The color future messages should be printed in. 
+ """ + print(f"\033[{color.value}m{message}\033[{end_color.value}m") diff --git a/scripts/mock_server.py b/scripts/mock_server.py index 814103e..eb3797c 100644 --- a/scripts/mock_server.py +++ b/scripts/mock_server.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from __future__ import annotations + import uuid from dataclasses import dataclass from enum import Enum @@ -28,40 +30,13 @@ import tornado.web import tornado.websocket +from lsst.rubintv.analysis.service.utils import printc, Colors + # Default port and address to listen on LISTEN_PORT = 2000 LISTEN_ADDRESS = "localhost" -# ANSI color codes for printing to the terminal -ansi_colors = { - "black": "30", - "red": "31", - "green": "32", - "yellow": "33", - "blue": "34", - "magenta": "35", - "cyan": "36", - "white": "37", -} - - -def log(message, color, end="\033[31"): - """Print a message to the terminal in color. - - Parameters - ---------- - message : - The message to print. - color : - The color to print the message in. - end : - The color future messages should be printed in. - """ - _color = ansi_colors[color] - print(f"\033[{_color}m{message}{end}m") - - class WorkerPodStatus(Enum): """Status of a worker pod.""" @@ -74,6 +49,10 @@ class WebSocketHandler(tornado.websocket.WebSocketHandler): Handler that handles WebSocket connections """ + workers: dict[str, WorkerPod] = dict() # Keep track of connected worker pods + clients: dict[str, WebSocketHandler] = dict() # Keep track of connected clients + queue: list[QueueItem] = list() # Queue of messages to be processed + @classmethod def urls(cls) -> list[tuple[str, type[tornado.web.RequestHandler], dict[str, str]]]: """url to handle websocket connections. 
@@ -85,7 +64,7 @@ def urls(cls) -> list[tuple[str, type[tornado.web.RequestHandler], dict[str, str (r"/ws/([^/]+)", cls, {}), # Route/Handler/kwargs ] - def open(self, type: str) -> None: + def open(self, client_type: str) -> None: """ Client opens a websocket @@ -95,12 +74,20 @@ def open(self, type: str) -> None: The type of client that is connecting. """ self.client_id = str(uuid.uuid4()) - if type == "worker": - workers[self.client_id] = WorkerPod(self.client_id, self) - log(f"New worker {self.client_id} connected. Total workers: {len(workers)}", "blue") - if type == "client": - clients[self.client_id] = self - log(f"New client {self.client_id} connected. Total clients: {len(clients)}", "yellow") + if client_type == "worker": + WebSocketHandler.workers[self.client_id] = WorkerPod(self.client_id, self) + printc( + f"New worker {self.client_id} connected. Total workers: {len(WebSocketHandler.workers)}", + Colors.BLUE, + Colors.RED, + ) + if client_type == "client": + WebSocketHandler.clients[self.client_id] = self + printc( + f"New client {self.client_id} connected. Total clients: {len(WebSocketHandler.clients)}", + Colors.YELLOW, + Colors.RED, + ) def on_message(self, message: str) -> None: """ @@ -111,35 +98,36 @@ def on_message(self, message: str) -> None: message : The message received from the client or worker. 
""" - if self.client_id in clients: - log(f"Message received from {self.client_id}", "yellow") - client = clients[self.client_id] + if self.client_id in WebSocketHandler.clients: + printc(f"Message received from {self.client_id}", Colors.YELLOW, Colors.RED) + client = WebSocketHandler.clients[self.client_id] # Find an idle worker idle_worker = None - for worker in workers.values(): + for worker in WebSocketHandler.workers.values(): if worker.status == WorkerPodStatus.IDLE: idle_worker = worker break if idle_worker is None: # No idle worker found, add to queue - queue.append(QueueItem(message, client)) + WebSocketHandler.queue.append(QueueItem(message, client)) return idle_worker.process(message, client) return - if self.client_id in workers: - worker = workers[self.client_id] + if self.client_id in WebSocketHandler.workers: + worker = WebSocketHandler.workers[self.client_id] worker.on_finished(message) - log( + printc( f"Message received from worker {self.client_id}. New status {worker.status}", - "blue", + Colors.BLUE, + Colors.RED, ) # Check the queue for any outstanding jobs. - if len(queue) > 0: - queue_item = queue.pop(0) + if len(WebSocketHandler.queue) > 0: + queue_item = WebSocketHandler.queue.pop(0) worker.process(queue_item.message, queue_item.client) return @@ -147,16 +135,24 @@ def on_close(self) -> None: """ Client closes the connection """ - if self.client_id in clients: - del clients[self.client_id] - log(f"Client disconnected. Active clients: {len(clients)}", "yellow") - for worker in workers.values(): + if self.client_id in WebSocketHandler.clients: + del WebSocketHandler.clients[self.client_id] + printc( + f"Client disconnected. Active clients: {len(WebSocketHandler.clients)}", + Colors.YELLOW, + Colors.RED, + ) + for worker in WebSocketHandler.workers.values(): if worker.connected_client == self: worker.on_finished("Client disconnected") break - if self.client_id in workers: - del workers[self.client_id] - log(f"Worker disconnected. 
Active workers: {len(workers)}", "blue") + if self.client_id in WebSocketHandler.workers: + del WebSocketHandler.workers[self.client_id] + printc( + f"Worker disconnected. Active workers: {len(WebSocketHandler.workers)}", + Colors.BLUE, + Colors.RED, + ) def check_origin(self, origin): """ @@ -183,8 +179,8 @@ class WorkerPod: status: WorkerPodStatus connected_client: WebSocketHandler | None - def __init__(self, id: str, ws: WebSocketHandler): - self.id = id + def __init__(self, wid: str, ws: WebSocketHandler): + self.wid = wid self.ws = ws self.status = WorkerPodStatus.IDLE self.connected_client = None @@ -201,7 +197,11 @@ def process(self, message: str, connected_client: WebSocketHandler): """ self.status = WorkerPodStatus.BUSY self.connected_client = connected_client - log(f"Worker {self.id} processing message from client {connected_client.client_id}", "blue") + printc( + f"Worker {self.wid} processing message from client {connected_client.client_id}", + Colors.BLUE, + Colors.RED, + ) # Send the job to the worker pod self.ws.write_message(message) @@ -215,7 +215,9 @@ def on_finished(self, message): # Send the reply to the client that made the request. 
self.connected_client.write_message(message) else: - log(f"Worker {self.id} finished processing, but no client was connected.", "red") + printc( + f"Worker {self.wid} finished processing, but no client was connected.", Colors.RED, Colors.RED + ) self.status = WorkerPodStatus.IDLE self.connected_client = None @@ -236,11 +238,6 @@ class QueueItem: client: WebSocketHandler -workers: dict[str, WorkerPod] = dict() # Keep track of connected worker pods -clients: dict[str, WebSocketHandler] = dict() # Keep track of connected clients -queue: list[QueueItem] = list() # Queue of messages to be processed - - def main(): # Create tornado application and supply URL routes app = tornado.web.Application(WebSocketHandler.urls()) # type: ignore @@ -249,7 +246,7 @@ def main(): http_server = tornado.httpserver.HTTPServer(app) http_server.listen(LISTEN_PORT, LISTEN_ADDRESS) - log(f"Listening on address: {LISTEN_ADDRESS}, {LISTEN_PORT}", "green") + printc(f"Listening on address: {LISTEN_ADDRESS}, {LISTEN_PORT}", Colors.GREEN, Colors.RED) # Start IO/Event loop tornado.ioloop.IOLoop.instance().start() diff --git a/scripts/rubintv_worker.py b/scripts/rubintv_worker.py index ad03071..97152bf 100644 --- a/scripts/rubintv_worker.py +++ b/scripts/rubintv_worker.py @@ -24,7 +24,7 @@ import pathlib import yaml -from lsst.rubintv.analysis.service.client import run_worker +from lsst.rubintv.analysis.service.client import Worker default_config = os.path.join(pathlib.Path(__file__).parent.absolute(), "config.yaml") @@ -47,7 +47,8 @@ def main(): config = yaml.safe_load(file) # Run the client and connect to rubinTV via websockets - run_worker(args.address, args.port, config) + worker = Worker(args.address, args.port, config) + worker.run() if __name__ == "__main__": From d1d9e411f47a0af239a4db56f66315eb771dc6ec Mon Sep 17 00:00:00 2001 From: Fred Moolekamp Date: Mon, 23 Oct 2023 16:37:25 -0400 Subject: [PATCH 3/3] Fix isort errors --- python/lsst/rubintv/analysis/service/client.py | 2 +- 
scripts/mock_server.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/lsst/rubintv/analysis/service/client.py b/python/lsst/rubintv/analysis/service/client.py index 38fba82..33781b4 100644 --- a/python/lsst/rubintv/analysis/service/client.py +++ b/python/lsst/rubintv/analysis/service/client.py @@ -26,7 +26,7 @@ from websocket import WebSocketApp from .command import DatabaseConnection, execute_command -from .utils import printc, Colors +from .utils import Colors, printc logger = logging.getLogger("lsst.rubintv.analysis.service.client") diff --git a/scripts/mock_server.py b/scripts/mock_server.py index eb3797c..e6fb659 100644 --- a/scripts/mock_server.py +++ b/scripts/mock_server.py @@ -29,8 +29,7 @@ import tornado.ioloop import tornado.web import tornado.websocket - -from lsst.rubintv.analysis.service.utils import printc, Colors +from lsst.rubintv.analysis.service.utils import Colors, printc # Default port and address to listen on LISTEN_PORT = 2000