Skip to content

Commit

Permalink
Fix tests for consDB-like schema
Browse files Browse the repository at this point in the history
  • Loading branch information
fred3m committed Jul 5, 2024
1 parent 44cb862 commit 6ee00db
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 307 deletions.
62 changes: 21 additions & 41 deletions tests/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,18 @@ description: Small database for testing the package
joins:
- type: inner
matches:
Visit:
- day_obs
- seq_num
- instrument
ExposureInfo:
- day_obs
- seq_num
- instrument
exposure:
- exposure_id
visit1_quicklook:
- visit_id
tables:
- name: Visit
- name: exposure
index_columns:
- day_obs
- seq_num
- instrument
- exposure_id
columns:
- name: exposure_id
datatype: long
description: Unique identifier for the exposure.
- name: seq_num
datatype: long
description: Sequence number
Expand All @@ -29,9 +26,6 @@ tables:
observation date, as this is the night that the observations started,
so for observations after midnight obsStart and obsNight will be
different days.
- name: instrument
datatype: char
description: Instrument name
- name: ra
datatype: double
unit: degree
Expand All @@ -40,38 +34,24 @@ tables:
datatype: double
unit: degree
description: Declination of focal plane center
- name: ExposureInfo
index_columns:
- day_obs
- seq_num
- instrument
columns:
- name: seq_num
datatype: long
description: Sequence number
- name: day_obs
datatype: date
description: The night of the observation. This is different than the
observation date, as this is the night that the observations started,
so for observations after midnight obsStart and obsNight will be
different days.
- name: instrument
datatype: char
description: Instrument name
- name: exposure_id
datatype: long
description: Unique identifier of an exposure.
- name: expTime
datatype: double
description: Spatially-averaged duration of exposure, accurate to 10ms.
- name: physical_filter
datatype: char
description: ID of physical filter,
the filter associated with a particular instrument.
- name: obsStart
- name: obs_start
datatype: datetime
description: Start time of the exposure at the fiducial center
of the focal plane array, TAI, accurate to 10ms.
- name: obsStartMJD
- name: obs_start_mjd
datatype: double
description: Start of the exposure in MJD, TAI, accurate to 10ms.
- name: visit1_quicklook
index_columns:
- visit_id
columns:
- name: visit_id
datatype: long
description: Unique identifier for the visit.
- name: exp_time
datatype: double
description: Spatially-averaged duration of exposure, accurate to 10ms.
78 changes: 23 additions & 55 deletions tests/test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,43 +20,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import json
import os
import tempfile
from typing import cast

import astropy.table
import lsst.rubintv.analysis.service as lras
import lsst.rubintv.analysis.service.database
import sqlalchemy
import utils
import yaml


class TestCommand(utils.RasTestCase):
def setUp(self):
path = os.path.dirname(__file__)
yaml_filename = os.path.join(path, "schema.yaml")

with open(yaml_filename) as file:
schema = yaml.safe_load(file)
db_file = tempfile.NamedTemporaryFile(delete=False)
utils.create_database(schema, db_file.name)
self.db_file = db_file
self.db_filename = db_file.name

# Load the database connection information
databases = {
"testdb": lsst.rubintv.analysis.service.database.ConsDbSchema(
schema=schema, engine=sqlalchemy.create_engine("sqlite:///" + db_file.name)
)
}

self.data_center = lras.data.DataCenter(schemas=databases)

def tearDown(self) -> None:
self.db_file.close()
os.remove(self.db_file.name)

def execute_command(self, command: dict, response_type: str) -> dict:
command_json = json.dumps(command)
response = lras.command.execute_command(command_json, self.data_center)
Expand All @@ -71,12 +42,12 @@ def test_calculate_bounds_command(self):
"name": "get bounds",
"parameters": {
"database": "testdb",
"column": "Visit.dec",
"column": "exposure.dec",
},
}
print(lras.command.BaseCommand.command_registry)
content = self.execute_command(command, "column bounds")
self.assertEqual(content["column"], "Visit.dec")
self.assertEqual(content["column"], "exposure.dec")
self.assertListEqual(content["bounds"], [-40, 50])


Expand All @@ -87,8 +58,8 @@ def test_load_full_columns(self):
"parameters": {
"database": "testdb",
"columns": [
"Visit.ra",
"Visit.dec",
"exposure.ra",
"exposure.dec",
],
},
}
Expand All @@ -98,15 +69,14 @@ def test_load_full_columns(self):

truth = cast(
astropy.table.Table,
utils.get_test_data("Visit")[
"Visit.ra",
"Visit.dec",
"Visit.day_obs",
"Visit.seq_num",
"Visit.instrument",
utils.get_test_data("exposure")[
"exposure.ra",
"exposure.dec",
"exposure.day_obs",
"exposure.seq_num",
],
)
valid = (truth["Visit.ra"] != None) & (truth["Visit.dec"] != None) # noqa: E711
valid = (truth["exposure.ra"] != None) & (truth["exposure.dec"] != None) # noqa: E711
truth = cast(astropy.table.Table, truth[valid])
self.assertDataTableEqual(data, truth)

Expand All @@ -116,14 +86,14 @@ def test_load_columns_with_query(self):
"parameters": {
"database": "testdb",
"columns": [
"ExposureInfo.exposure_id",
"Visit.ra",
"Visit.dec",
"visit1_quicklook.visit_id",
"exposure.ra",
"exposure.dec",
],
"query": {
"name": "EqualityQuery",
"content": {
"column": "ExposureInfo.expTime",
"column": "visit1_quicklook.exp_time",
"operator": "eq",
"value": 30,
},
Expand All @@ -134,26 +104,24 @@ def test_load_columns_with_query(self):
content = self.execute_command(command, "table columns")
data = content["data"]

visit_truth = utils.get_test_data("Visit")
exp_truth = utils.get_test_data("ExposureInfo")
visit_truth = utils.get_test_data("exposure")
exp_truth = utils.get_test_data("visit1_quicklook")
truth = astropy.table.join(
visit_truth,
exp_truth,
keys_left=("Visit.seq_num", "Visit.day_obs", "Visit.instrument"),
keys_right=("ExposureInfo.seq_num", "ExposureInfo.day_obs", "ExposureInfo.instrument"),
keys_left=("exposure.exposure_id",),
keys_right=("visit1_quicklook.visit_id",),
)
truth = truth[
"ExposureInfo.exposure_id",
"Visit.ra",
"Visit.dec",
"ExposureInfo.day_obs",
"ExposureInfo.seq_num",
"ExposureInfo.instrument",
"visit1_quicklook.visit_id",
"exposure.ra",
"exposure.dec",
"exposure.day_obs",
"exposure.seq_num",
]

# Select rows with expTime = 30
truth = truth[[True, True, False, False, False, True, False, False, False, False]]
print(data.keys())
self.assertDataTableEqual(data, truth)


Expand Down
51 changes: 25 additions & 26 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,60 +30,59 @@ def test_get_table_names(self):
self.assertTupleEqual(
table_names,
(
"Visit",
"ExposureInfo",
"exposure",
"visit1_quicklook",
),
)

def test_get_table_schema(self):
schema = lras.database.get_table_schema(self.database.schema, "ExposureInfo")
self.assertEqual(schema["name"], "ExposureInfo")
schema = lras.database.get_table_schema(self.database.schema, "exposure")
self.assertEqual(schema["name"], "exposure")

columns = [
"exposure_id",
"seq_num",
"day_obs",
"instrument",
"exposure_id",
"expTime",
"ra",
"dec",
"physical_filter",
"obsStart",
"obsStartMJD",
"obs_start",
"obs_start_mjd",
]
for n, column in enumerate(schema["columns"]):
self.assertEqual(column["name"], columns[n])

def test_single_table_query_columns(self):
truth = utils.get_test_data("Visit")
valid = (truth["Visit.ra"] != None) & (truth["Visit.dec"] != None) # noqa: E711
truth = utils.get_test_data("exposure")
valid = (truth["exposure.ra"] != None) & (truth["exposure.dec"] != None) # noqa: E711
truth = truth[valid]
truth = truth["Visit.ra", "Visit.dec", "Visit.day_obs", "Visit.seq_num", "Visit.instrument"]
data = self.database.query(columns=["Visit.ra", "Visit.dec"])
truth = truth["exposure.ra", "exposure.dec", "exposure.day_obs", "exposure.seq_num"]
data = self.database.query(columns=["exposure.ra", "exposure.dec"])
self.assertDataTableEqual(data, truth) # type: ignore

def test_multiple_table_query_columns(self):
visit_truth = utils.get_test_data("Visit")
exp_truth = utils.get_test_data("ExposureInfo")
visit_truth = utils.get_test_data("exposure")
exp_truth = utils.get_test_data("visit1_quicklook")
truth = astropy.table.join(
visit_truth,
exp_truth,
keys_left=("Visit.seq_num", "Visit.day_obs", "Visit.instrument"),
keys_right=("ExposureInfo.seq_num", "ExposureInfo.day_obs", "ExposureInfo.instrument"),
keys_left=("exposure.exposure_id"),
keys_right=("visit1_quicklook.visit_id"),
)
valid = (truth["Visit.ra"] != None) & (truth["Visit.dec"] != None) # noqa: E711
valid = (truth["exposure.ra"] != None) & (truth["exposure.dec"] != None) # noqa: E711
truth = truth[valid]
truth = truth[
"Visit.ra",
"Visit.dec",
"ExposureInfo.exposure_id",
"Visit.day_obs",
"Visit.seq_num",
"Visit.instrument",
"exposure.ra",
"exposure.dec",
"visit1_quicklook.visit_id",
"exposure.day_obs",
"exposure.seq_num",
]

data = self.database.query(columns=["Visit.ra", "Visit.dec", "ExposureInfo.exposure_id"])
data = self.database.query(columns=["exposure.ra", "exposure.dec", "visit1_quicklook.visit_id"])

self.assertDataTableEqual(data, truth)

def test_calculate_bounds(self):
result = self.database.calculate_bounds("Visit.dec")
result = self.database.calculate_bounds("exposure.dec")
self.assertTupleEqual(result, (-40, 50))
Loading

0 comments on commit 6ee00db

Please sign in to comment.