SITCOM-1834: Trim returned schema by removing the empty columns #11

Merged: 8 commits, Feb 10, 2025
26 changes: 15 additions & 11 deletions python/lsst/rubintv/analysis/service/commands/db.py
@@ -181,30 +181,34 @@ def build_contents(self, data_center: DataCenter) -> dict:
                 camera = Latiss.getCamera()
             case "lsstcomcamsim":
                 camera = LsstComCamSim.getCamera()
+            case "testdb":
+                camera = None
             case _:
                 raise ValueError(f"Unsupported instrument: {instrument}")

         detectors = []
-        for detector in camera:
-            corners = [(c.getX(), c.getY()) for c in detector.getCorners(FOCAL_PLANE)]
-            detectors.append(
-                {
-                    "id": detector.getId(),
-                    "name": detector.getName(),
-                    "corners": corners,
-                }
-            )
+        if camera is not None:
+            for detector in camera:
+                corners = [(c.getX(), c.getY()) for c in detector.getCorners(FOCAL_PLANE)]
+                detectors.append(
+                    {
+                        "id": detector.getId(),
+                        "name": detector.getName(),
+                        "corners": corners,
+                    }
+                )

         result = {
             "instrument": self.instrument,
             "detectors": detectors,
         }

         # Load the data base to access the schema
-        schema_name = f"cdb_{instrument}"
+        schema_name = f"cdb_{instrument}" if instrument != "testdb" else "testdb"
         try:
             database = data_center.schemas[schema_name]
-            result["schema"] = database.schema
+            result["schema"] = database.get_verified_schema()
+
         except KeyError:
             logger.warning(f"No database connection available for {schema_name}")
             logger.warning(f"Available databases: {data_center.schemas.keys()}")
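Illustration (not part of the diff): with the test fixtures added later in this PR, "trimming" the schema means dropping any column whose rows are all NULL before the schema is returned to the client. A minimal sketch of the before/after shape, using only the column names from tests/schema.yaml (real schema entries also carry datatype and description metadata not shown here):

    # Illustrative only: shape of the schema dict for the test database.
    full_schema = {
        "tables": [
            {
                "name": "visit1_quicklook",
                "columns": [{"name": "visit_id"}, {"name": "exp_time"}, {"name": "empty_column"}],
            }
        ]
    }
    # get_verified_schema() drops "empty_column" because every row is NULL,
    # so the client only sees columns that actually contain data.
    trimmed_schema = {
        "tables": [
            {
                "name": "visit1_quicklook",
                "columns": [{"name": "visit_id"}, {"name": "exp_time"}],
            }
        ]
    }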
67 changes: 67 additions & 0 deletions python/lsst/rubintv/analysis/service/database.py
@@ -563,3 +563,70 @@ def calculate_bounds(self, column: str) -> tuple[float, float]:
         else:
             raise ValueError(f"Could not calculate the max of column {column}")
         return col_min, col_max
+
+    def has_non_null_values(self, column: str) -> bool:
+        """Check if a column contains any non-null values.
+
+        Parameters
+        ----------
+        column : str
+            The column to check, in the format "table.column".
+
+        Returns
+        -------
+        bool
+            True if the column contains at least one non-null value, False
+            otherwise.
+        """
+        try:
+            table_name, column_name = column.split(".")
+            if table_name not in self.tables:
+                logger.warning(f"Table '{table_name}' not found in database schema.")
+                return False
+
+            _table = self.tables[table_name]
+            _column = _table.columns[column_name]
+
+            # Query to check if at least one non-null value exists
+            query = sqlalchemy.select(_column).where(_column.isnot(None)).limit(1)
+
+            with self.engine.connect() as connection:
+                result = connection.execute(query).fetchone()
+
+            return result is not None  # True if we found at least one non-null value
+
+        except Exception as e:
+            logger.error(f"Error checking non-null values for column '{column}': {e}")
+            return False
+
+    def get_verified_schema(self):
+        all_columns = [
+            f"{table['name']}.{column['name']}"
+            for table in self.schema.get("tables", [])
+            for column in table.get("columns", [])
+        ]
+        filtered_table_columns = [column for column in all_columns if self.has_non_null_values(column)]
+
+        if filtered_table_columns is None:
+            return self.schema  # Return full schema if no filtering is needed
+
+        filtered_schema = self.schema.copy()
+        filtered_schema["tables"] = []
+
+        # Process tables dynamically
+        for table in self.schema.get("tables", []):
+            filtered_columns = [
+                column
+                for column in table.get("columns", [])
+                if f"{table['name']}.{column['name']}" in filtered_table_columns
+            ]
+            if filtered_columns:
+                # Preserve all table metadata dynamically
+                filtered_table = {key: value for key, value in table.items() if key != "columns"}
+                filtered_table["columns"] = filtered_columns
+                filtered_schema["tables"].append(filtered_table)
+
+            if not filtered_columns:
+                logger.warning(f"All columns in {self.schema['name']} are empty. Returning an empty schema.")
+
+        return filtered_schema
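Illustration (not part of the diff): has_non_null_values issues one SELECT ... WHERE column IS NOT NULL LIMIT 1 query per column, so an all-NULL column is detected without fetching the whole table. A self-contained sketch of the same SQLAlchemy pattern against an in-memory SQLite table (table and column names here are only examples borrowed from the test fixtures):

    import sqlalchemy

    engine = sqlalchemy.create_engine("sqlite:///:memory:")
    metadata = sqlalchemy.MetaData()
    visits = sqlalchemy.Table(
        "visit1_quicklook",
        metadata,
        sqlalchemy.Column("visit_id", sqlalchemy.Integer),
        sqlalchemy.Column("empty_column", sqlalchemy.String),
    )
    metadata.create_all(engine)
    with engine.begin() as connection:
        connection.execute(visits.insert(), [{"visit_id": 2, "empty_column": None}])

    def has_data(column: sqlalchemy.Column) -> bool:
        # Same pattern as has_non_null_values: fetch at most one non-null row.
        query = sqlalchemy.select(column).where(column.isnot(None)).limit(1)
        with engine.connect() as connection:
            return connection.execute(query).fetchone() is not None

    print(has_data(visits.c.visit_id))      # True: the column has data
    print(has_data(visits.c.empty_column))  # False: every row is NULL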
1 change: 0 additions & 1 deletion scripts/mock_server.py
@@ -242,7 +242,6 @@ def main():
     for logger_name in [
         "lsst.rubintv.analysis.service.worker",
         "lsst.rubintv.analysis.service.client",
-        "lsst.rubintv.analysis.service.server",
     ]:
         logger = logging.getLogger(logger_name)
         logger.setLevel(log_level)
3 changes: 3 additions & 0 deletions tests/schema.yaml
@@ -90,3 +90,6 @@ tables:
   - name: exp_time
     datatype: double
     description: Spatially-averaged duration of exposure, accurate to 10ms.
+  - name: empty_column
+    datatype: char
+    description: An empty column for testing purposes.
33 changes: 33 additions & 0 deletions tests/test_command.py
@@ -21,6 +21,7 @@

 import json
 from typing import cast
+from unittest.mock import MagicMock, patch

 import astropy.table
 import lsst.rubintv.analysis.service as lras
@@ -37,6 +38,38 @@ def execute_command(self, command: dict, response_type: str) -> dict:
         return result["content"]


+class TestLoadInstrumentCommand(TestCommand):
+    @patch.dict(
+        "sys.modules",
+        {
+            "lsst": MagicMock(),
+            "lsst.obs": MagicMock(),
+            "lsst.obs.lsst": MagicMock(),
+            "lsst.afw.cameraGeom": MagicMock(),
+        },
+    )
+    def test_load_instrument_command(self):
+        command = {
+            "name": "load instrument",
+            "parameters": {
+                "instrument": "testdb",
+            },
+        }
+        content = self.execute_command(command, response_type="instrument info")
+
+        # Check that empty columns are not included in the returned schema
+        data = utils.get_visit_data_dict()
+        if "visit1_quicklook.empty_column" in data:
+            del data["visit1_quicklook.empty_column"]
+
+        visit1_quicklook = [
+            table for table in content["schema"]["tables"] if table["name"] == "visit1_quicklook"
+        ][0]
+        column_names = [f"visit1_quicklook.{column['name']}" for column in visit1_quicklook["columns"]]
+
+        self.assertListEqual(column_names, list(data.keys()))
+
+
 class TestLoadColumnsWithAggregatorCommand(TestCommand):
     def setUpTest(self):
         """
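Illustration (not part of the diff): the new test stubs out the LSST camera-geometry packages by patching sys.modules with MagicMock objects, so the "testdb" code path runs without the real LSST stack installed. A minimal standalone example of that pattern (the module name "heavy_pkg" is made up):

    from unittest.mock import MagicMock, patch

    # "heavy_pkg" stands in for a dependency we don't want to import for real.
    with patch.dict("sys.modules", {"heavy_pkg": MagicMock(), "heavy_pkg.camera": MagicMock()}):
        import heavy_pkg.camera  # resolved from the patched sys.modules, no real install needed

        camera = heavy_pkg.camera.getCamera()  # returns a MagicMock, good enough for this test
        print(type(camera).__name__)  # MagicMock

    # Once the with block exits, the fake entries are removed from sys.modules again.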
1 change: 1 addition & 0 deletions tests/utils.py
@@ -126,6 +126,7 @@ def get_visit_data_dict() -> dict:
     return {
         "visit1_quicklook.visit_id": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20],
         "visit1_quicklook.exp_time": [30, 30, 10, 15, 15, 30, 30, 30, 15, 20],
+        "visit1_quicklook.empty_column": [None] * 10,
     }