From 6e537a01763b46e627a257243550643b779cbcfb Mon Sep 17 00:00:00 2001 From: Fred Moolekamp Date: Thu, 21 Sep 2023 17:48:01 -0400 Subject: [PATCH] Initial code and docs commit --- .github/pull_request_template.md | 4 + .github/workflows/build.yaml | 55 +++ .github/workflows/build_docs.yaml | 41 +++ .github/workflows/formatting.yaml | 11 + .github/workflows/lint.yaml | 11 + .github/workflows/yamllint.yaml | 11 + .pre-commit-config.yaml | 32 ++ doc/.gitignore | 10 + doc/SConscript | 3 + doc/conf.py | 12 + doc/doxygen.conf.in | 0 doc/index.rst | 12 + doc/lsst.rubintv.analysis.service/index.rst | 40 +++ doc/manifest.yaml | 12 + mypy.ini | 22 ++ pyproject.toml | 7 +- python/lsst/__init__.py | 2 +- python/lsst/rubintv/__init__.py | 2 +- python/lsst/rubintv/analysis/__init__.py | 2 +- .../lsst/rubintv/analysis/service/__init__.py | 1 + .../lsst/rubintv/analysis/service/command.py | 315 ++++++++++++++++++ .../lsst/rubintv/analysis/service/database.py | 170 ++++++++++ python/lsst/rubintv/analysis/service/query.py | 151 +++++++++ requirements.txt | 7 + tests/schema.yaml | 42 +++ tests/test_command.py | 197 +++++++++++ tests/test_database.py | 103 ++++++ tests/test_query.py | 280 ++++++++++++++++ tests/utils.py | 197 +++++++++++ 29 files changed, 1745 insertions(+), 7 deletions(-) create mode 100644 .github/pull_request_template.md create mode 100644 .github/workflows/build.yaml create mode 100644 .github/workflows/build_docs.yaml create mode 100644 .github/workflows/formatting.yaml create mode 100644 .github/workflows/lint.yaml create mode 100644 .github/workflows/yamllint.yaml create mode 100644 .pre-commit-config.yaml create mode 100644 doc/.gitignore create mode 100644 doc/SConscript create mode 100644 doc/conf.py create mode 100644 doc/doxygen.conf.in create mode 100644 doc/index.rst create mode 100644 doc/lsst.rubintv.analysis.service/index.rst create mode 100644 doc/manifest.yaml create mode 100644 mypy.ini create mode 100644 
python/lsst/rubintv/analysis/service/__init__.py create mode 100644 python/lsst/rubintv/analysis/service/command.py create mode 100644 python/lsst/rubintv/analysis/service/database.py create mode 100644 python/lsst/rubintv/analysis/service/query.py create mode 100644 requirements.txt create mode 100644 tests/schema.yaml create mode 100644 tests/test_command.py create mode 100644 tests/test_database.py create mode 100644 tests/test_query.py create mode 100644 tests/utils.py diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..7b345f4 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,4 @@ +## Checklist + +- [ ] ran Jenkins +- [ ] added a release note for user-visible changes to `doc/changes` diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 0000000..0055df1 --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,55 @@ +name: build_and_test + +on: + push: + branches: + - main + tags: + - "*" + pull_request: + +jobs: + build_and_test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + # Need to clone everything for the git tags. + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + cache: "pip" + cache-dependency-path: "setup.cfg" + + - name: Install yaml + run: sudo apt-get install libyaml-dev + + - name: Install prereqs for setuptools + run: pip install wheel + + # We have two cores so we can speed up the testing with xdist + - name: Install xdist, openfiles and flake8 for pytest + run: > + pip install pytest-xdist pytest-openfiles pytest-flake8 + pytest-cov "flake8<5" + + - name: Build and install + run: pip install -v -e . 
+ + - name: Install documenteer + run: pip install 'documenteer[pipelines]<0.7' + + - name: Run tests + run: > + pytest -r a -v -n 3 --open-files --cov=tests + --cov=lsst.rubintv.analysis.service + --cov-report=xml --cov-report=term + --doctest-modules --doctest-glob="*.rst" + + - name: Upload coverage to codecov + uses: codecov/codecov-action@v2 + with: + file: ./coverage.xml diff --git a/.github/workflows/build_docs.yaml b/.github/workflows/build_docs.yaml new file mode 100644 index 0000000..75d1dac --- /dev/null +++ b/.github/workflows/build_docs.yaml @@ -0,0 +1,41 @@ +name: docs + +on: + push: + branches: + - main + pull_request: + +jobs: + build_sphinx_docs: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + # Need to clone everything for the git tags. + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + cache: "pip" + cache-dependency-path: "setup.cfg" + + - name: Update pip/wheel infrastructure + run: | + python -m pip install --upgrade pip + pip install wheel + + - name: Build and install + run: pip install -v -e . 
+ + - name: Show compiled files + run: ls python/lsst/rubintv/analysis/service + + - name: Install documenteer + run: pip install 'documenteer[pipelines]<0.7' + + - name: Build documentation + working-directory: ./doc + run: package-docs build diff --git a/.github/workflows/formatting.yaml b/.github/workflows/formatting.yaml new file mode 100644 index 0000000..27f34a6 --- /dev/null +++ b/.github/workflows/formatting.yaml @@ -0,0 +1,11 @@ +name: Check Python formatting + +on: + push: + branches: + - main + pull_request: + +jobs: + call-workflow: + uses: lsst/rubin_workflows/.github/workflows/formatting.yaml@main diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 0000000..796ef92 --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,11 @@ +name: lint + +on: + push: + branches: + - main + pull_request: + +jobs: + call-workflow: + uses: lsst/rubin_workflows/.github/workflows/lint.yaml@main diff --git a/.github/workflows/yamllint.yaml b/.github/workflows/yamllint.yaml new file mode 100644 index 0000000..76ad875 --- /dev/null +++ b/.github/workflows/yamllint.yaml @@ -0,0 +1,11 @@ +name: Lint YAML Files + +on: + push: + branches: + - main + pull_request: + +jobs: + call-workflow: + uses: lsst/rubin_workflows/.github/workflows/yamllint.yaml@main diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a07ff8a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,32 @@ +repos: + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-yaml + args: + - "--unsafe" + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + # It is recommended to specify the latest version of Python + + # supported by your project here, or alternatively use + + # pre-commit's default_language_version, see + + # https://pre-commit.com/#top_level-default_language_version + + language_version: 
python3.10 + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: isort (python) + - repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 diff --git a/doc/.gitignore b/doc/.gitignore new file mode 100644 index 0000000..ad2c2bb --- /dev/null +++ b/doc/.gitignore @@ -0,0 +1,10 @@ +# Doxygen products +html +xml +*.tag +*.inc +doxygen.conf + +# Sphinx products +_build +py-api diff --git a/doc/SConscript b/doc/SConscript new file mode 100644 index 0000000..61b554a --- /dev/null +++ b/doc/SConscript @@ -0,0 +1,3 @@ +# -*- python -*- +from lsst.sconsUtils import scripts +scripts.BasicSConscript.doc() diff --git a/doc/conf.py b/doc/conf.py new file mode 100644 index 0000000..0dcae36 --- /dev/null +++ b/doc/conf.py @@ -0,0 +1,12 @@ +"""Sphinx configuration file for an LSST stack package. +This configuration only affects single-package Sphinx documentation builds. +For more information, see: +https://developer.lsst.io/stack/building-single-package-docs.html +""" + +from documenteer.conf.pipelinespkg import * + +project = "rubintv_analysis_service" +html_theme_options["logotext"] = project +html_title = project +html_short_title = project diff --git a/doc/doxygen.conf.in b/doc/doxygen.conf.in new file mode 100644 index 0000000..e69de29 diff --git a/doc/index.rst b/doc/index.rst new file mode 100644 index 0000000..f7713f5 --- /dev/null +++ b/doc/index.rst @@ -0,0 +1,12 @@ +############################################## +rubintv_analysis_service documentation preview +############################################## + +.. This page is for local development only. It isn't published to pipelines.lsst.io. + +.. Link the index pages of package and module documentation directions (listed in manifest.yaml). + +.. 
toctree:: + :maxdepth: 1 + + lsst.rubintv.analysis.service/index diff --git a/doc/lsst.rubintv.analysis.service/index.rst b/doc/lsst.rubintv.analysis.service/index.rst new file mode 100644 index 0000000..3707f25 --- /dev/null +++ b/doc/lsst.rubintv.analysis.service/index.rst @@ -0,0 +1,40 @@ +.. py:currentmodule:: lsst.rubintv.analysis.service + +.. _lsst.rubintv.analysis.service: + +############################# +lsst.rubintv.analysis.service +############################# + +.. Paragraph that describes what this Python module does and links to related modules and frameworks. + +.. _lsst.rubintv.analysis.service-using: + +Using lsst.rubintv.analysis.service +======================= + +toctree linking to topics related to using the module's APIs. + +.. toctree:: + :maxdepth: 2 + +.. _lsst.rubintv.analysis.service-contributing: + +Contributing +============ + +``lsst.rubintv.analysis.service`` is developed at https://github.com/lsst-ts/rubintv_analysis_service. + +.. If there are topics related to developing this module (rather than using it), link to this from a toctree placed here. + +.. .. toctree:: +.. :maxdepth: 2 + +.. _lsst.rubintv.analysis.service-pyapi: + +Python API reference +==================== + +.. automodapi:: lsst.rubintv.analysis.service + :no-main-docstr: + :no-inheritance-diagram: diff --git a/doc/manifest.yaml b/doc/manifest.yaml new file mode 100644 index 0000000..222eb97 --- /dev/null +++ b/doc/manifest.yaml @@ -0,0 +1,12 @@ +# Documentation manifest. + +# List of names of Python modules in this package. +# For each module there is a corresponding module doc subdirectory. +modules: + - "lsst.rubintv.analysis.service" + +# Name of the static content directories (subdirectories of `_static`). +# Static content directories are usually named after the package. +# Most packages do not need a static content directory (leave commented out). 
+# statics: +# - "_static/example_standalone" diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..3d32554 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,22 @@ +[mypy] +warn_unused_configs = True +warn_redundant_casts = True +plugins = pydantic.mypy + +[mypy-astropy.*] +ignore_missing_imports = True + +[mypy-matplotlib.*] +ignore_missing_imports = True + +[mypy-numpy.*] +ignore_missing_imports = True + +[mypy-scipy.*] +ignore_missing_imports = True + +[mypy-sqlalchemy.*] +ignore_missing_imports = True + +[mypy-yaml.*] +ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml index 3b951cd..5a4e43d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,9 @@ dependencies = [ "scipy", "matplotlib", "pydantic", - "pyyaml" + "pyyaml", + "sqlalchemy", + "astropy", ] #dynamic = ["version"] @@ -65,9 +67,6 @@ line_length = 110 [tool.lsst_versions] write_to = "python/lsst/rubintv/analysis/service/version.py" -[tool.pytest.ini_options] -addopts = "--flake8" -flake8-ignore = ["W503", "E203"] # The matplotlib test may not release font files. 
open_files_ignore = ["*.ttf"] diff --git a/python/lsst/__init__.py b/python/lsst/__init__.py index eb1f6e6..f77af49 100644 --- a/python/lsst/__init__.py +++ b/python/lsst/__init__.py @@ -1,3 +1,3 @@ import pkgutil -__path__ = pkgutil.extend_path(__path__, __name__) \ No newline at end of file +__path__ = pkgutil.extend_path(__path__, __name__) diff --git a/python/lsst/rubintv/__init__.py b/python/lsst/rubintv/__init__.py index eb1f6e6..f77af49 100644 --- a/python/lsst/rubintv/__init__.py +++ b/python/lsst/rubintv/__init__.py @@ -1,3 +1,3 @@ import pkgutil -__path__ = pkgutil.extend_path(__path__, __name__) \ No newline at end of file +__path__ = pkgutil.extend_path(__path__, __name__) diff --git a/python/lsst/rubintv/analysis/__init__.py b/python/lsst/rubintv/analysis/__init__.py index eb1f6e6..f77af49 100644 --- a/python/lsst/rubintv/analysis/__init__.py +++ b/python/lsst/rubintv/analysis/__init__.py @@ -1,3 +1,3 @@ import pkgutil -__path__ = pkgutil.extend_path(__path__, __name__) \ No newline at end of file +__path__ = pkgutil.extend_path(__path__, __name__) diff --git a/python/lsst/rubintv/analysis/service/__init__.py b/python/lsst/rubintv/analysis/service/__init__.py new file mode 100644 index 0000000..e144f74 --- /dev/null +++ b/python/lsst/rubintv/analysis/service/__init__.py @@ -0,0 +1 @@ +from . import command, database, query diff --git a/python/lsst/rubintv/analysis/service/command.py b/python/lsst/rubintv/analysis/service/command.py new file mode 100644 index 0000000..10e868d --- /dev/null +++ b/python/lsst/rubintv/analysis/service/command.py @@ -0,0 +1,315 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. 
# This file is part of lsst_rubintv_analysis_service.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import json
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass

import sqlalchemy

from . import database

logger = logging.getLogger("lsst.rubintv.analysis.service.command")


def construct_error_message(error_name: str, description: str) -> str:
    """Use a standard format for all error messages.

    Parameters
    ----------
    error_name :
        Name of the error.
    description :
        Description of the error.

    Returns
    -------
    result :
        JSON formatted string.
    """
    return json.dumps(
        {
            "type": "error",
            "content": {
                "error": error_name,
                "description": description,
            },
        }
    )


def error_msg(error: Exception) -> str:
    """Handle errors received while parsing or executing a command.

    Parameters
    ----------
    error :
        The error that was raised while parsing, executing,
        or responding to a command.

    Returns
    -------
    response :
        The JSON formatted error message sent to the user.
    """
    if isinstance(error, json.decoder.JSONDecodeError):
        return construct_error_message("JSON decoder error", error.args[0])

    if isinstance(error, CommandParsingError):
        return construct_error_message("parsing error", error.args[0])

    if isinstance(error, CommandExecutionError):
        return construct_error_message("execution error", error.args[0])

    if isinstance(error, CommandResponseError):
        return construct_error_message("command response error", error.args[0])

    # We should always receive one of the above errors, so the code should
    # never get to here. But we generate this response just in case something
    # very unexpected happens, or (more likely) the code is altered in such a
    # way that this line is hit.
    msg = "An unknown error occurred, you should never reach this message."
    return construct_error_message(error.__class__.__name__, msg)


class CommandParsingError(Exception):
    """An `~Exception` caused by an error in parsing a command and
    constructing a response.
    """

    pass


class CommandExecutionError(Exception):
    """An error occurred while executing a command."""

    pass


class CommandResponseError(Exception):
    """An error occurred while converting a command result to JSON."""

    pass


@dataclass(kw_only=True)
class BaseCommand(ABC):
    """Base class for commands.

    Attributes
    ----------
    result :
        The response generated by the command as a `dict` that can
        be converted into JSON.
    response_type :
        The type of response that this command sends to the user.
        This should be unique for each command.
    """

    result: dict | None = None
    response_type: str

    @abstractmethod
    def build_contents(self, schema: dict, engine: sqlalchemy.engine.Engine) -> dict:
        """Build the contents of the command.

        Parameters
        ----------
        schema :
            The schema of the full visit database.
        engine :
            The engine used to connect to the database.

        Returns
        -------
        contents :
            The contents of the response to the user.
        """
        pass

    def execute(self, schema: dict, engine: sqlalchemy.engine.Engine):
        """Execute the command.

        This method does not return anything, but sets the `result`,
        the JSON formatted string that is sent to the user.

        Parameters
        ----------
        schema :
            The schema of the full visit database.
        engine :
            The engine used to connect to the database.
        """
        self.result = {"type": self.response_type, "content": self.build_contents(schema, engine)}

    def to_json(self):
        """Convert the `result` into JSON.

        Raises
        ------
        CommandExecutionError
            If `execute` has not been called (`result` is still ``None``).
        """
        if self.result is None:
            raise CommandExecutionError(f"Null result for command {self.__class__.__name__}")
        return json.dumps(self.result)


@dataclass(kw_only=True)
class LoadColumnsCommand(BaseCommand):
    """Load columns from a database table with an optional query.

    Attributes
    ----------
    table :
        The table that the columns are loaded from.
    columns :
        Columns that are to be loaded. If `columns` is ``None``
        then all the columns in the `table` are loaded.
    query :
        Query used to select rows in the table.
        If `query` is ``None`` then all the rows are loaded.
    """

    table: str
    columns: list[str] | None = None
    query: dict | None = None
    response_type: str = "load columns"

    def build_contents(self, schema: dict, engine: sqlalchemy.engine.Engine) -> dict:
        # Query the database to return the requested columns.
        # The table's index column is always included so that rows
        # can be cross-matched by the client.
        index_column = database.get_table_schema(schema, self.table)["index_column"]
        columns = self.columns
        if columns is not None and index_column not in columns:
            columns = [index_column] + columns
        data = database.query_table(
            table=self.table,
            columns=columns,
            query=self.query,
            engine=engine,
        )

        if len(data) == 0:
            # There is no column data to return.
            # Normalize ``columns`` to a list so the content type is
            # consistent (the original could return ``None`` here).
            content: dict = {
                "columns": columns if columns is not None else [],
                "data": [],
            }
        else:
            content = {
                "columns": [column for column in data[0]._fields],
                "data": [list(row) for row in data],
            }

        return content


@dataclass(kw_only=True)
class CalculateBoundsCommand(BaseCommand):
    """Calculate the bounds of a table column.

    Attributes
    ----------
    table :
        The table that the columns are loaded from.
    column :
        The column to calculate the bounds of.
    """

    table: str
    column: str
    response_type: str = "column bounds"

    def build_contents(self, schema: dict, engine: sqlalchemy.engine.Engine) -> dict:
        data = database.calculate_bounds(
            table=self.table,
            column=self.column,
            engine=engine,
        )
        return {
            "column": self.column,
            "bounds": data,
        }


def parse_command(command_str: str, schema: dict, engine: sqlalchemy.engine.Engine) -> str:
    """Parse a JSON formatted string into a command that
    the service can execute.

    Command format:
    ```
    {
        name: command name,
        content: command content (usually a dict)
    }
    ```

    Parameters
    ----------
    command_str :
        The JSON formatted command received from the user.
    schema :
        The schema of the full visit database.
    engine :
        The engine used to connect to the database.

    Returns
    -------
    result :
        The JSON formatted response to the command, which is either the
        command's result or a JSON error message (errors never propagate).
    """
    try:
        command_dict = json.loads(command_str)
        if not isinstance(command_dict, dict):
            raise CommandParsingError(f"Could not generate a valid command from {command_str}")
    except Exception as err:
        # Use the module logger rather than the root logger
        # (the original called ``logging.exception``).
        logger.exception("Error converting command to JSON.")
        return error_msg(err)

    try:
        if "name" not in command_dict.keys():
            raise CommandParsingError("No command 'name' given")

        if command_dict["name"] not in commands.keys():
            raise CommandParsingError(f"Unrecognized command '{command_dict['name']}'")

        if "parameters" in command_dict:
            parameters = command_dict["parameters"]
        else:
            parameters = {}

        command = commands[command_dict["name"]](**parameters)

    except Exception as err:
        logger.exception("Error parsing command.")
        return error_msg(CommandParsingError(f"'{err}' error while parsing command"))

    try:
        command.execute(schema, engine)
    except Exception as err:
        logger.exception("Error executing command.")
        return error_msg(CommandExecutionError(f"{err} error executing command."))

    try:
        result = command.to_json()
    except Exception as err:
        logger.exception("Error converting command response to JSON.")
        return error_msg(CommandResponseError(f"{err} error converting command response to JSON."))

    return result


# Registry mapping command names (as received from the user) to
# the `BaseCommand` subclass that implements them.
commands = {
    "load columns": LoadColumnsCommand,
    "get bounds": CalculateBoundsCommand,
}
# This file is part of lsst_rubintv_analysis_service.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from typing import Sequence

import sqlalchemy

from .query import Query


class UnrecognizedTableError(Exception):
    """An error that occurs when a table name does not appear in the schema."""

    pass


def get_table_names(schema: dict) -> tuple[str, ...]:
    """Given a schema, return the names of all of its tables.

    Parameters
    ----------
    schema :
        The schema for a database.

    Returns
    -------
    result :
        The names of all the tables in the database.
    """
    return tuple(tbl["name"] for tbl in schema["tables"])


def get_table_schema(schema: dict, table: str) -> dict:
    """Get the schema for a table from the database schema.

    Parameters
    ----------
    schema :
        The schema for a database.
    table :
        The name of the table in the database.

    Returns
    -------
    result :
        The schema for the table.

    Raises
    ------
    UnrecognizedTableError
        If `table` is not in the schema.
    """
    for _table in schema["tables"]:
        if _table["name"] == table:
            return _table
    # The original raised with a plain string, so the table name
    # was never interpolated into the message.
    raise UnrecognizedTableError(f"Could not find the table '{table}' in database")


def column_names_to_models(table: sqlalchemy.Table, columns: list[str]) -> list[sqlalchemy.Column]:
    """Return the sqlalchemy model of a Table column for each column name.

    This method is used to generate a sqlalchemy query based on a `~Query`.

    Parameters
    ----------
    table :
        The table in the database.
    columns :
        The names of the columns to generate models for.

    Returns
    -------
    A list of sqlalchemy columns.
    """
    return [getattr(table.columns, column) for column in columns]


def query_table(
    table: str,
    engine: sqlalchemy.engine.Engine,
    columns: list[str] | None = None,
    query: dict | None = None,
) -> Sequence[sqlalchemy.engine.row.Row]:
    """Query a table and return the results.

    Parameters
    ----------
    table :
        The table that is being queried.
    engine :
        The engine used to connect to the database.
    columns :
        The columns from the table to return.
        If `columns` is ``None`` then all the columns
        in the table are returned.
    query :
        A query used on the table.
        If `query` is ``None`` then all the rows
        in the table are returned.

    Returns
    -------
    result :
        A list of the rows that were returned by the query.
    """
    metadata = sqlalchemy.MetaData()
    _table = sqlalchemy.Table(table, metadata, autoload_with=engine)

    if columns is None:
        _query = _table.select()
    else:
        _query = sqlalchemy.select(*column_names_to_models(_table, columns))

    if query is not None:
        _query = _query.where(Query.from_dict(query)(_table))

    # Use a context manager so the connection is always closed
    # (the original leaked the connection).
    with engine.connect() as connection:
        result = connection.execute(_query)
        return result.fetchall()


def calculate_bounds(table: str, column: str, engine: sqlalchemy.engine.Engine) -> tuple[float, float]:
    """Calculate the min, max for a column.

    Parameters
    ----------
    table :
        The table that is being queried.
    column :
        The column to calculate the bounds of.
    engine :
        The engine used to connect to the database.

    Returns
    -------
    result :
        The ``(min, max)`` of the chosen column.
    """
    metadata = sqlalchemy.MetaData()
    _table = sqlalchemy.Table(table, metadata, autoload_with=engine)
    _column = _table.columns[column]

    # Compute both bounds in a single round trip; the original issued
    # two separate queries on two separate (never-closed) connections.
    query = sqlalchemy.select(sqlalchemy.func.min(_column), sqlalchemy.func.max(_column))
    with engine.connect() as connection:
        row = connection.execute(query).fetchone()

    if row is None:
        raise ValueError(f"Could not calculate the bounds of column '{column}'")
    col_min, col_max = row
    return col_min, col_max
from __future__ import annotations

import operator as op
from abc import ABC, abstractmethod
from typing import Any

import sqlalchemy


class QueryError(Exception):
    """An error that occurred during a query."""

    pass


class Query(ABC):
    """Base class for constructing queries."""

    @abstractmethod
    def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList:
        """Run the query on a table.

        Parameters
        ----------
        table :
            The table to run the query on.
        """
        pass

    @staticmethod
    def from_dict(query_dict: dict[str, Any]) -> Query:
        """Construct a query from a dictionary of parameters.

        Parameters
        ----------
        query_dict :
            Kwargs used to initialize the query.
            There should only be two keys in this dict,
            the ``name`` of the query and the ``content`` used
            to initialize the query.

        Raises
        ------
        QueryError
            If the dict is malformed or names an unknown query type.
        """
        try:
            if query_dict["name"] == "EqualityQuery":
                return EqualityQuery.from_dict(query_dict["content"])
            elif query_dict["name"] == "ParentQuery":
                return ParentQuery.from_dict(query_dict["content"])
        except Exception as err:
            # Chain the cause so the original failure is not lost.
            raise QueryError("Failed to parse query.") from err

        raise QueryError("Unrecognized query type")


class EqualityQuery(Query):
    """A query that compares a column to a static value.

    Parameters
    ----------
    column :
        The column used in the query.
    operator :
        The operator to use for the query.
    value :
        The value that the column is compared to.
    """

    def __init__(
        self,
        column: str,
        operator: str,
        value: Any,
    ):
        self.operator = operator
        self.column = column
        self.value = value

    def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList:
        column = table.columns[self.column]

        # Comparison operators map directly onto the stdlib operator module.
        if self.operator in ("eq", "ne", "lt", "le", "gt", "ge"):
            operator = getattr(op, self.operator)
            return operator(column, self.value)

        # String operators map onto the sqlalchemy column methods.
        if self.operator not in ("startswith", "endswith", "contains"):
            raise QueryError(f"Unrecognized Equality operator {self.operator}")

        return getattr(column, self.operator)(self.value)

    @staticmethod
    def from_dict(query_dict: dict[str, Any]) -> EqualityQuery:
        return EqualityQuery(**query_dict)


class ParentQuery(Query):
    """A query that uses a binary operation to combine other queries.

    Parameters
    ----------
    children :
        The child queries that are combined using the binary operator.
    operator :
        The operator that is used to combine the queries.
    """

    def __init__(self, children: list[Query], operator: str):
        self.children = children
        self.operator = operator

    def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList:
        child_results = [child(table) for child in self.children]

        # Reject unknown operators explicitly; the original silently
        # fell off the end of the method and returned ``None``.
        if self.operator not in ("AND", "OR", "NOT", "XOR"):
            raise QueryError(f"Unrecognized boolean operator '{self.operator}'")

        try:
            if self.operator == "AND":
                return sqlalchemy.and_(*child_results)
            if self.operator == "OR":
                return sqlalchemy.or_(*child_results)
            if self.operator == "NOT":
                return sqlalchemy.not_(*child_results)
            # XOR: at least one child is true, but not all of them.
            return sqlalchemy.and_(
                sqlalchemy.or_(*child_results),
                sqlalchemy.not_(sqlalchemy.and_(*child_results)),
            )
        except Exception as err:
            raise QueryError("Error applying a boolean query statement.") from err

    @staticmethod
    def from_dict(query_dict: dict[str, Any]) -> ParentQuery:
        return ParentQuery(
            children=[Query.from_dict(child) for child in query_dict["children"]],
            operator=query_dict["operator"],
        )
+ - name: physical_filter + datatype: char + description: ID of physical filter, + the filter associated with a particular instrument. + - name: obsNight + datatype: date + description: The night of the observation. This is different than the + observation date, as this is the night that the observations started, + so for observations after midnight obsStart and obsNight will be + different days. + - name: obsStart + datatype: datetime + description: Start time of the exposure at the fiducial center + of the focal plane array, TAI, accurate to 10ms. + - name: obsStartMJD + datatype: double + description: Start of the exposure in MJD, TAI, accurate to 10ms. diff --git a/tests/test_command.py b/tests/test_command.py new file mode 100644 index 0000000..a5a2941 --- /dev/null +++ b/tests/test_command.py @@ -0,0 +1,197 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+ +import json +import os +import tempfile + +import lsst.rubintv.analysis.service as lras +import sqlalchemy +import utils +import yaml + + +class TestCommand(utils.RasTestCase): + def setUp(self): + path = os.path.dirname(__file__) + yaml_filename = os.path.join(path, "schema.yaml") + + with open(yaml_filename) as file: + schema = yaml.safe_load(file) + db_file = tempfile.NamedTemporaryFile(delete=False) + utils.create_database(schema, db_file.name) + self.db_file = db_file + self.db_filename = db_file.name + self.schema = schema + + # Set up the sqlalchemy connection + self.engine = sqlalchemy.create_engine("sqlite:///" + db_file.name) + + def tearDown(self) -> None: + self.db_file.close() + os.remove(self.db_file.name) + + def execute_command(self, command: dict, response_type: str) -> dict: + command_json = json.dumps(command) + response = lras.command.parse_command(command_json, self.schema, self.engine) + result = json.loads(response) + self.assertEqual(result["type"], response_type) + return result["content"] + + +class TestCalculateBoundsCommand(TestCommand): + def test_calculate_bounds_command(self): + command = { + "name": "get bounds", + "parameters": { + "table": "ExposureInfo", + "column": "dec", + }, + } + content = self.execute_command(command, "column bounds") + self.assertEqual(content["column"], "dec") + self.assertListEqual(content["bounds"], [-40, 50]) + + +class TestLoadColumnsCommand(TestCommand): + def test_load_full_dataset(self): + command = {"name": "load columns", "parameters": {"table": "ExposureInfo"}} + + content = self.execute_command(command, "load columns") + data = content["data"] + + truth = utils.ap_table_to_list(utils.get_test_data()) + + self.assertDataTableEqual(data, truth) + + def test_load_full_columns(self): + command = { + "name": "load columns", + "parameters": { + "table": "ExposureInfo", + "columns": [ + "ra", + "dec", + ], + }, + } + + content = self.execute_command(command, "load columns") + columns = 
content["columns"] + data = content["data"] + + truth = utils.get_test_data()["exposure_id", "ra", "dec"] + truth_data = utils.ap_table_to_list(truth) + + self.assertTupleEqual(tuple(columns), tuple(truth.columns)) + self.assertDataTableEqual(data, truth_data) + + def test_load_columns_with_query(self): + command = { + "name": "load columns", + "parameters": { + "table": "ExposureInfo", + "columns": [ + "exposure_id", + "ra", + "dec", + ], + "query": { + "name": "EqualityQuery", + "content": { + "column": "expTime", + "operator": "eq", + "value": 30, + }, + }, + }, + } + + content = self.execute_command(command, "load columns") + columns = content["columns"] + data = content["data"] + + truth = utils.get_test_data()["exposure_id", "ra", "dec"] + # Select rows with expTime = 30 + truth = truth[[True, True, False, False, False, True, True, True, False, False]] + truth_data = utils.ap_table_to_list(truth) + + self.assertTupleEqual(tuple(columns), tuple(truth.columns)) + self.assertDataTableEqual(data, truth_data) + + +class TestCommandErrors(TestCommand): + def check_error_response(self, content: dict, error: str, description: str | None = None): + self.assertEqual(content["error"], error) + if description is not None: + self.assertEqual(content["description"], description) + + def test_errors(self): + # Command cannot be decoded as JSON dict + content = self.execute_command("{'test': [1,2,3,0004,}", "error") + self.check_error_response(content, "parsing error") + + # Command does not contain a "name" + command = {"content": {}} + content = self.execute_command(command, "error") + self.check_error_response( + content, + "parsing error", + "'No command 'name' given' error while parsing command", + ) + + # Command has an invalid name + command = {"name": "invalid name"} + content = self.execute_command(command, "error") + self.check_error_response( + content, + "parsing error", + "'Unrecognized command 'invalid name'' error while parsing command", + ) + + # Command has 
no parameters
+        command = {"name": "get bounds"}
+        content = self.execute_command(command, "error")
+        self.check_error_response(
+            content,
+            "parsing error",
+        )
+
+        # Command has invalid parameters
+        command = {
+            "name": "get bounds",
+            "parameters": {
+                "a": 1,
+            },
+        }
+        content = self.execute_command(command, "error")
+        self.check_error_response(
+            content,
+            "parsing error",
+        )
+
+        # Command execution failed (table name does not exist)
+        command = {"name": "get bounds", "parameters": {"table": "InvalidTable", "column": "invalid_column"}}
+        content = self.execute_command(command, "error")
+        self.check_error_response(
+            content,
+            "execution error",
+        )
diff --git a/tests/test_database.py b/tests/test_database.py
new file mode 100644
index 0000000..94d6914
--- /dev/null
+++ b/tests/test_database.py
@@ -0,0 +1,103 @@
+# This file is part of lsst_rubintv_analysis_service.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+ +import os +import tempfile +from unittest import TestCase + +import lsst.rubintv.analysis.service as lras +import sqlalchemy +import utils +import yaml + + +class TestDatabase(TestCase): + def setUp(self): + path = os.path.dirname(__file__) + yaml_filename = os.path.join(path, "schema.yaml") + + with open(yaml_filename) as file: + schema = yaml.safe_load(file) + db_file = tempfile.NamedTemporaryFile(delete=False) + utils.create_database(schema, db_file.name) + self.db_file = db_file + self.db_filename = db_file.name + self.schema = schema + + # Set up the sqlalchemy connection + self.engine = sqlalchemy.create_engine("sqlite:///" + db_file.name) + + def tearDown(self) -> None: + self.db_file.close() + os.remove(self.db_file.name) + + def test_get_table_names(self): + table_names = lras.database.get_table_names(self.schema) + self.assertTupleEqual(table_names, ("ExposureInfo",)) + + def test_get_table_schema(self): + schema = lras.database.get_table_schema(self.schema, "ExposureInfo") + self.assertEqual(schema["name"], "ExposureInfo") + + columns = [ + "exposure_id", + "seq_num", + "ra", + "dec", + "expTime", + "physical_filter", + "obsNight", + "obsStart", + "obsStartMJD", + ] + for n, column in enumerate(schema["columns"]): + self.assertEqual(column["name"], columns[n]) + + def test_query_full_table(self): + truth_table = utils.get_test_data() + truth = utils.ap_table_to_list(truth_table) + + data = lras.database.query_table("ExposureInfo", engine=self.engine) + print(data) + + self.assertListEqual(list(data[0]._fields), list(truth_table.columns)) + + for n in range(len(truth)): + true_row = tuple(truth[n]) + row = tuple(data[n]) + self.assertTupleEqual(row, true_row) + + def test_query_columns(self): + truth = utils.get_test_data() + truth = utils.ap_table_to_list(truth["ra", "dec"]) + + data = lras.database.query_table("ExposureInfo", columns=["ra", "dec"], engine=self.engine) + + self.assertListEqual(list(data[0]._fields), ["ra", "dec"]) + + for n in 
range(len(truth)):
+            true_row = tuple(truth[n])
+            row = tuple(data[n])
+            self.assertTupleEqual(row, true_row)
+
+    def test_calculate_bounds(self):
+        result = lras.database.calculate_bounds("ExposureInfo", "dec", self.engine)
+        self.assertTupleEqual(result, (-40, 50))
diff --git a/tests/test_query.py b/tests/test_query.py
new file mode 100644
index 0000000..a125218
--- /dev/null
+++ b/tests/test_query.py
@@ -0,0 +1,280 @@
+# This file is part of lsst_rubintv_analysis_service.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+ +import os +import tempfile + +import lsst.rubintv.analysis.service as lras +import sqlalchemy +import utils +import yaml + + +class TestQuery(utils.RasTestCase): + def setUp(self): + path = os.path.dirname(__file__) + yaml_filename = os.path.join(path, "schema.yaml") + + with open(yaml_filename) as file: + schema = yaml.safe_load(file) + db_file = tempfile.NamedTemporaryFile(delete=False) + utils.create_database(schema, db_file.name) + self.db_file = db_file + self.db_filename = db_file.name + self.schema = schema + + # Set up the sqlalchemy connection + self.engine = sqlalchemy.create_engine("sqlite:///" + db_file.name) + self.metadata = sqlalchemy.MetaData() + self.table = sqlalchemy.Table("ExposureInfo", self.metadata, autoload_with=self.engine) + + def tearDown(self) -> None: + self.db_file.close() + os.remove(self.db_file.name) + + def test_equality(self): + table = self.table + column = table.columns.dec + + value = 0 + truth_dict = { + "eq": column == value, + "ne": column != value, + "lt": column < value, + "le": column <= value, + "gt": column > value, + "ge": column >= value, + } + + for operator, truth in truth_dict.items(): + self.assertTrue(lras.query.EqualityQuery("dec", operator, value)(table).compare(truth)) + + def test_query(self): + table = self.table + + # dec > 0 + query = lras.query.EqualityQuery("dec", "gt", 0) + result = query(table) + self.assertTrue(result.compare(table.columns.dec > 0)) + + # dec < 0 and ra > 60 + query = lras.query.ParentQuery( + operator="AND", + children=[ + lras.query.EqualityQuery("dec", "lt", 0), + lras.query.EqualityQuery("ra", "gt", 60), + ], + ) + result = query(table) + truth = sqlalchemy.and_( + table.columns.dec < 0, + table.columns.ra > 60, + ) + self.assertTrue(result.compare(truth)) + + # Check queries that are unequal to verify that they don't work + result = query(table) + truth = sqlalchemy.and_( + table.columns.dec < 0, + table.columns.ra > 70, + ) + self.assertFalse(result.compare(truth)) + + def 
test_database_query(self): + data = utils.get_test_data() + + # dec > 0 (and is not None) + query1 = { + "name": "EqualityQuery", + "content": { + "column": "dec", + "operator": "gt", + "value": 0, + }, + } + # ra > 60 (and is not None) + query2 = { + "name": "EqualityQuery", + "content": { + "column": "ra", + "operator": "gt", + "value": 60, + }, + } + + # Test 1: dec > 0 (and is not None) + query = query1 + result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + truth = data[[False, False, False, False, False, True, False, True, True, True]] + truth = utils.ap_table_to_list(truth) + self.assertDataTableEqual(result, truth) + + # Test 2: dec > 0 and ra > 60 (and neither is None) + query = { + "name": "ParentQuery", + "content": { + "operator": "AND", + "children": [query1, query2], + }, + } + result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + truth = data[[False, False, False, False, False, False, False, False, True, True]] + truth = utils.ap_table_to_list(truth) + self.assertDataTableEqual(result, truth) + + # Test 3: dec <= 0 or ra > 60 (and neither is None) + query = { + "name": "ParentQuery", + "content": { + "operator": "OR", + "children": [ + { + "name": "ParentQuery", + "content": { + "operator": "NOT", + "children": [query1], + }, + }, + query2, + ], + }, + } + + result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + truth = data[[True, True, False, True, True, False, True, False, True, True]] + truth = utils.ap_table_to_list(truth) + self.assertDataTableEqual(result, truth) + + # Test 4: dec > 0 XOR ra > 60 + query = { + "name": "ParentQuery", + "content": { + "operator": "XOR", + "children": [query1, query2], + }, + } + result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + truth = data[[False, False, False, False, False, True, False, False, False, False]] + truth = utils.ap_table_to_list(truth) + 
self.assertDataTableEqual(result, truth) + + def test_database_string_query(self): + data = utils.get_test_data() + + # Test equality + query = { + "name": "EqualityQuery", + "content": { + "column": "physical_filter", + "operator": "eq", + "value": "DECam r-band", + }, + } + result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + truth = data[[False, False, False, False, False, False, True, False, False, False]] + truth = utils.ap_table_to_list(truth) + self.assertDataTableEqual(result, truth) + + # Test "startswith" + query = { + "name": "EqualityQuery", + "content": { + "column": "physical_filter", + "operator": "startswith", + "value": "DECam", + }, + } + result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + truth = data[[False, False, False, False, False, True, True, True, True, True]] + truth = utils.ap_table_to_list(truth) + self.assertDataTableEqual(result, truth) + + # Test "endswith" + query = { + "name": "EqualityQuery", + "content": { + "column": "physical_filter", + "operator": "endswith", + "value": "r-band", + }, + } + result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + truth = data[[False, True, False, False, False, False, True, False, False, False]] + truth = utils.ap_table_to_list(truth) + self.assertDataTableEqual(result, truth) + + # Test "like" + query = { + "name": "EqualityQuery", + "content": { + "column": "physical_filter", + "operator": "contains", + "value": "T r", + }, + } + result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + truth = data[[False, True, False, False, False, False, False, False, False, False]] + truth = utils.ap_table_to_list(truth) + self.assertDataTableEqual(result, truth) + + def test_database_datatime_query(self): + data = utils.get_test_data() + + # Test < + query1 = { + "name": "EqualityQuery", + "content": { + "column": "obsStart", + "operator": "lt", + "value": "2023-05-19 23:23:23", + 
}, + } + result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query1) + truth = data[[True, True, True, False, False, True, True, True, True, True]] + truth = utils.ap_table_to_list(truth) + self.assertDataTableEqual(result, truth) + + # Test > + query2 = { + "name": "EqualityQuery", + "content": { + "column": "obsStart", + "operator": "gt", + "value": "2023-05-01 23:23:23", + }, + } + result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query2) + truth = data[[True, True, True, True, True, False, False, False, False, False]] + truth = utils.ap_table_to_list(truth) + self.assertDataTableEqual(result, truth) + + # Test in range + query3 = { + "name": "ParentQuery", + "content": { + "operator": "AND", + "children": [query1, query2], + }, + } + result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query3) + truth = data[[True, True, True, False, False, False, False, False, False, False]] + truth = utils.ap_table_to_list(truth) + self.assertDataTableEqual(result, truth) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..5b78e5b --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,197 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import sqlite3 +from unittest import TestCase + +from astropy.table import Table as ApTable +from astropy.time import Time + +# Convert visit DB datatypes to sqlite3 datatypes +datatype_transform = { + "int": "integer", + "long": "integer", + "double": "real", + "float": "real", + "char": "text", + "date": "text", + "datetime": "text", +} + + +def create_table(cursor: sqlite3.Cursor, tbl_name: str, schema: dict): + """Create a table in an sqlite database. + + Parameters + ---------- + cursor : + The cursor associated with the database connection. + tbl_name : + The name of the table to create. + schema : + The schema of the table. + """ + command = f"CREATE TABLE {tbl_name}(\n" + for field in schema: + command += f' {field["name"]} {datatype_transform[field["datatype"]]},\n' + command = command[:-2] + "\n);" + cursor.execute(command) + + +def get_test_data_dict() -> dict: + """Get a dictionary containing the test data""" + obs_start = [ + "2023-05-19 20:20:20", + "2023-05-19 21:21:21", + "2023-05-19 22:22:22", + "2023-05-19 23:23:23", + "2023-05-20 00:00:00", + "2023-02-14 22:22:22", + "2023-02-14 23:23:23", + "2023-02-14 00:00:00", + "2023-02-14 01:01:01", + "2023-02-14 02:02:02", + ] + + obs_start_mjd = [Time(time).mjd for time in obs_start] + + return { + "exposure_id": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18], + "seq_num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "ra": [10, 20, None, 40, 50, 60, 70, None, 90, 100], + "dec": [-40, -30, None, -10, 0, 10, None, 30, 40, 50], + "expTime": [30, 30, 10, 15, 15, 30, 30, 30, 15, 20], + "physical_filter": [ + "LSST g-band", + "LSST r-band", + "LSST i-band", + "LSST z-band", + "LSST y-band", + "DECam g-band", + "DECam r-band", + "DECam i-band", + "DECam z-band", + "DECam y-band", + ], + "obsNight": [ + "2023-05-19", + "2023-05-19", + "2023-05-19", + "2023-05-19", 
+ "2023-05-19", + "2023-02-14", + "2023-02-14", + "2023-02-14", + "2023-02-14", + "2023-02-14", + ], + "obsStart": obs_start, + "obsStartMJD": obs_start_mjd, + } + + +def get_test_data() -> ApTable: + """Generate data for the test database""" + data_dict = get_test_data_dict() + + table = ApTable(list(data_dict.values()), names=list(data_dict.keys())) + return table + + +def ap_table_to_list(data: ApTable) -> list: + """Convert an astropy Table into a list of tuples.""" + rows = [] + for row in data: + rows.append(tuple(row)) + return rows + + +def create_database(schema: dict, db_filename: str): + """Create the test database""" + tbl_name = "ExposureInfo" + connection = sqlite3.connect(db_filename) + cursor = connection.cursor() + + create_table(cursor, tbl_name, schema["tables"][0]["columns"]) + + data = get_test_data_dict() + + for n in range(len(data["exposure_id"])): + row = tuple(data[key][n] for key in data.keys()) + value_str = "?, " * (len(row) - 1) + "?" + command = f"INSERT INTO {tbl_name} VALUES({value_str});" + cursor.execute(command, row) + connection.commit() + cursor.close() + + +class TableMismatchError(AssertionError): + pass + + +class RasTestCase(TestCase): + """Base class for tests in this package + + For now this only includes methods to check the + database results, but in the future other checks + might be put in place. + """ + + @staticmethod + def get_data_table_indices(table: list[tuple]) -> list[int]: + """Get the index for each rom in the data table. + + Parameters + ---------- + table : + The table containing the data. + + Returns + ------- + result : + The index for each row in the table. + """ + # Return the seq_num as an index + return [row[1] for row in table] + + def assertDataTableEqual(self, result, truth): + """Check if two data tables are equal. + + Parameters + ---------- + result : + The result generated by the test that is checked. + truth : + The expected value of the test. 
+ """ + if len(result) != len(truth): + msg = "Data tables have a different number of rows: " + msg += f"indices: [{self.get_data_table_indices(result)}], [{self.get_data_table_indices(truth)}]" + raise TableMismatchError(msg) + try: + for n in range(len(truth)): + true_row = tuple(truth[n]) + row = tuple(result[n]) + self.assertTupleEqual(row, true_row) + except AssertionError: + msg = "Mismatched tables: " + msg += f"indices: [{self.get_data_table_indices(result)}], [{self.get_data_table_indices(truth)}]" + raise TableMismatchError(msg)