Skip to content

Commit

Permalink
Hide afw import so that tests can pass
Browse files Browse the repository at this point in the history
  • Loading branch information
fred3m committed Jul 7, 2024
1 parent 5f19a93 commit 9889a81
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 220 deletions.
10 changes: 7 additions & 3 deletions python/lsst/rubintv/analysis/service/commands/butler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from lsst.afw.cameraGeom import FOCAL_PLANE, Camera, Detector
from lsst.obs.lsst import Latiss, LsstCam, LsstComCam

from ..command import BaseCommand

if TYPE_CHECKING:
from ..data import DataCenter
from lsst.afw.cameraGeom import Camera


def get_camera(instrument_name: str) -> Camera:
Expand All @@ -46,6 +44,9 @@ def get_camera(instrument_name: str) -> Camera:
camera : Camera
The camera object.
"""
# Import afw packages here to prevent tests from failing
from lsst.obs.lsst import Latiss, LsstCam, LsstComCam

instrument_name = instrument_name.lower()
match instrument_name:
case "lsstcam":
Expand Down Expand Up @@ -73,6 +74,9 @@ class LoadDetectorInfoCommand(BaseCommand):
response_type: str = "detector_info"

def build_contents(self, data_center: DataCenter) -> dict:
# Import afw packages here to prevent tests from failing
from lsst.afw.cameraGeom import FOCAL_PLANE, Detector

# Load the detector information from the Butler
camera = get_camera(self.instrument)
detector_info = {}
Expand Down
6 changes: 3 additions & 3 deletions python/lsst/rubintv/analysis/service/commands/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from lsst.afw.cameraGeom import FOCAL_PLANE
from lsst.obs.lsst import Latiss, LsstCam, LsstComCam, LsstComCamSim

from ..command import BaseCommand
from ..database import exposure_tables, visit1_tables
from ..query import EqualityQuery, ParentQuery, Query
Expand Down Expand Up @@ -170,6 +167,9 @@ class LoadInstrumentCommand(BaseCommand):
response_type: str = "instrument info"

def build_contents(self, data_center: DataCenter) -> dict:
from lsst.afw.cameraGeom import FOCAL_PLANE
from lsst.obs.lsst import Latiss, LsstCam, LsstComCam, LsstComCamSim

instrument = self.instrument.lower()

match instrument:
Expand Down
189 changes: 68 additions & 121 deletions python/lsst/rubintv/analysis/service/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import cast

import sqlalchemy

Expand Down Expand Up @@ -82,78 +80,68 @@ def get_table_schema(schema: dict, table: str) -> dict:
raise UnrecognizedTableError("Could not find the table '{table}' in database")


class Join(ABC):
"""A join between two tables in a database.
class EnhancedJoinBuilder:
# Store the table models and the raw join templates, then derive the join graph.
def __init__(self, tables: dict[str, sqlalchemy.Table], joins: list[dict]):
# tables: sqlalchemy table models keyed by table name.
self.tables = tables
# joins: join templates; each one is read through its "matches" entry.
self.joins = joins
# Built last because _build_join_graph reads self.tables and self.joins.
self.join_graph = self._build_join_graph()

def _build_join_graph(self) -> dict[str, dict[str, list[str]]]:
graph = {table: {} for table in self.tables}
for join in self.joins:
tables = list(join["matches"].keys())
t1, t2 = tables[0], tables[1]
join_columns = list(zip(join["matches"][t1], join["matches"][t2]))
graph[t1][t2] = join_columns
graph[t2][t1] = [(col2, col1) for col1, col2 in join_columns]
return graph

def _find_join_path(self, start: str, end: str) -> list[str]:
queue = [(start, [start])]
visited = set()

while queue:
(node, path) = queue.pop(0)
if node not in visited:
if node == end:
return path
visited.add(node)
for neighbor in self.join_graph[node]:
if neighbor not in visited:
queue.append((neighbor, path + [neighbor]))
return []

Attributes
----------
join_type :
The type of join. For now only "inner" joins are supported.
"""

join_type: str

@abstractmethod
def __call__(self, database: ConsDbSchema):
pass


class InnerJoin(Join):
"""An inner join between two tables in a database.
Attributes
----------
n_columns :
The number of columns in the join.
matches :
Dictionary with table names as keys and tuples of column names as
values in the order in which they are matched in the join.
"""

n_columns: int
matches: dict[str, tuple[str, ...]]

def __init__(self, matches: dict[str, tuple[str, ...]]):
self.join_type = "inner"
if len(matches) != 2:
raise ValueError(f"Inner joins must have exactly two tables: got {len(matches)}")

n_columns = 0
for _, fields in matches.items():
if n_columns == 0:
n_columns = len(fields)
else:
if n_columns != len(fields):
raise ValueError(
"Inner joins must have the same number of fields for each table: "
f"got {n_columns} and {len(fields)}"
)
self.n_columns = n_columns
self.matches = matches
def build_join(self, table_names: set[str]) -> sqlalchemy.Table | sqlalchemy.Join:
tables = list(table_names)
select_from = self.tables[tables[0]]

def __call__(self, database: ConsDbSchema):
"""Create the sqlalchemy join between the two tables.
for i in range(1, len(tables)):
current_table = tables[i]
previous_table = tables[i - 1]
join_path = self._find_join_path(previous_table, current_table)

if not join_path:
raise ValueError(f"No join path found between {previous_table} and {current_table}")

for j in range(1, len(join_path)):
t1, t2 = join_path[j - 1], join_path[j]
join_conditions = []
for col1, col2 in self.join_graph[t1][t2]:
try:
condition = self.tables[t1].columns[col1] == self.tables[t2].columns[col2]
join_conditions.append(condition)
except KeyError as e:
logger.error(f"Column not found: {e}")
logger.error(f"Available columns in {t1}: {self.tables[t1].columns.keys()}")
logger.error(f"Available columns in {t2}: {self.tables[t2].columns.keys()}")
raise

if not join_conditions:
raise ValueError(f"No valid join conditions found between {t1} and {t2}")

select_from = sqlalchemy.join(select_from, self.tables[t2], *join_conditions)

Parameters
----------
database :
The database connection.
"""
tables = tuple(self.matches.keys())
table1 = tables[0]
table2 = tables[1]
table_model1 = database.tables[table1]
table_model2 = database.tables[table2]
joins = []
print("matches is", self.matches)
print("table1 is", table1)
print("table2 is", table2)
for index in range(self.n_columns):
joins.append(
table_model1.columns[self.matches[table1][index]]
== table_model2.columns[self.matches[table2][index]]
)
return sqlalchemy.and_(*joins)
return select_from


class ConsDbSchema:
Expand All @@ -175,23 +163,13 @@ class ConsDbSchema:
schema: dict
metadata: sqlalchemy.MetaData
tables: dict[str, sqlalchemy.Table]
joins: dict[str, tuple[Join, ...]]
joins: EnhancedJoinBuilder

def __init__(self, engine: sqlalchemy.engine.Engine, schema: dict, join_templates: list):
self.engine = engine
self.schema = schema
self.metadata = sqlalchemy.MetaData()

joins = {}
for join in join_templates:
if join["type"] == "inner":
if "inner" not in joins:
joins["inner"] = []
joins["inner"].append(InnerJoin(join["matches"]))
else:
raise NotImplementedError(f"Join type {join['type']} is not implemented")
self.joins = {key: tuple(value) for key, value in joins.items()}

self.tables = {}
for table in schema["tables"]:
if (
Expand All @@ -209,6 +187,8 @@ def __init__(self, engine: sqlalchemy.engine.Engine, schema: dict, join_template
schema=schema["name"],
)

self.joins = EnhancedJoinBuilder(self.tables, join_templates)

def get_table_names(self) -> tuple[str, ...]:
"""Given a schema, return a list of dataset names
Expand Down Expand Up @@ -267,41 +247,6 @@ def get_column(self, column: str) -> sqlalchemy.Column:
table, column = column.split(".")
return self.tables[table].columns[column]

def get_join(self, table1: str, table2: str) -> sqlalchemy.ColumnElement:
"""Return the join between two tables.
Parameters
----------
table1 :
The first table in the join.
table2 :
The second table in the join.
Returns
-------
result :
The join between the two tables.
"""
joins = cast(tuple[InnerJoin, ...], self.joins["inner"])
for join in joins:
tables = join.matches.keys()
if table1 in tables and table2 in tables:
return join(self)

raise ValueError(f"Could not find a join between {table1} and {table2}")

def build_join(self, table_names: set[str]) -> sqlalchemy.Table | sqlalchemy.Join:
tables = list(table_names)
select_from = self.tables[tables[0]]
print("tables are", tables)
for i in range(1, len(tables)):
current_table = tables[i]
previous_table = tables[i-1]
print("current:", current_table, "previous:", previous_table)
join = self.get_join(previous_table, current_table)
select_from = sqlalchemy.join(select_from, self.tables[current_table], join)
return select_from

def fetch_data(self, query_model: sqlalchemy.Select) -> dict[str, list]:
# Temporary, for testing. TODO: remove this code block before merging
_log_level = logger.getEffectiveLevel()
Expand Down Expand Up @@ -351,10 +296,12 @@ def add_data_ids(table_name: str) -> list[sqlalchemy.Column]:
table_columns.add(seq_num_column.label("seq_num"))
return [day_obs_column, seq_num_column]

if "visit1" in table_names:
if list(table_names)[0] in visit1_tables:
data_id_columns = add_data_ids("visit1")
else:
elif list(table_names)[0] in exposure_tables:
data_id_columns = add_data_ids("exposure")
else:
raise ValueError(f"Unsupported table name: {list(table_names)[0]}")

return table_columns, table_names, data_id_columns

Expand Down Expand Up @@ -389,13 +336,13 @@ def query(
if query is not None:
query_result = query(self)
query_model = sqlalchemy.and_(query_model, query_result.result)
table_names.add(*query_result.tables)
table_names.update(query_result.tables)
if data_ids is not None:
data_id_select = sqlalchemy.tuple_(day_obs_column, seq_num_column).in_(data_ids)
query_model = sqlalchemy.and_(query_model, data_id_select)

# Build the join
select_from = self.build_join(table_names)
select_from = self.joins.build_join(table_names)

# Build the query
query_model = sqlalchemy.select(*table_columns).select_from(select_from).where(query_model)
Expand Down
Loading

0 comments on commit 9889a81

Please sign in to comment.