
Commit

Merge pull request #1081 from lsst/tickets/DM-46363
DM-46363: Inject SqlQueryContext into ObsCoreManager at top level
dhirving authored Sep 17, 2024
2 parents 910bec6 + 4627fb8 commit ec1fe6a
Showing 5 changed files with 49 additions and 46 deletions.
16 changes: 7 additions & 9 deletions python/lsst/daf/butler/registry/interfaces/_obscore.py
@@ -44,9 +44,9 @@
if TYPE_CHECKING:
from lsst.sphgeom import Region

from ..._column_type_info import ColumnTypeInfo
from ..._dataset_ref import DatasetRef
from ...dimensions import DimensionUniverse
from ..queries import SqlQueryContext
from ._collections import CollectionRecord
from ._database import Database, StaticTablesContext
from ._datasets import DatasetRecordStorageManager
@@ -103,6 +103,7 @@ def initialize(
datasets: type[DatasetRecordStorageManager],
dimensions: DimensionRecordStorageManager,
registry_schema_version: VersionTuple | None = None,
column_type_info: ColumnTypeInfo,
) -> ObsCoreTableManager:
"""Construct an instance of the manager.
@@ -124,6 +125,9 @@
Manager for Registry dimensions.
registry_schema_version : `VersionTuple` or `None`
Schema version of this extension as defined in registry.
column_type_info : `ColumnTypeInfo`
Information about column types that can differ between data
repositories and registry instances.
Returns
-------
@@ -144,7 +148,7 @@ def config_json(self) -> str:
raise NotImplementedError()

@abstractmethod
def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
def add_datasets(self, refs: Iterable[DatasetRef]) -> int:
"""Possibly add datasets to the obscore table.
This method should be called when new datasets are added to a RUN
@@ -156,8 +160,6 @@ def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) ->
Dataset refs to add. Dataset refs have to be completely expanded.
If a record with the same dataset ID is already in obscore table,
the dataset is ignored.
context : `SqlQueryContext`
Context used to execute queries for additional dimension metadata.
Returns
-------
@@ -180,9 +182,7 @@ def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) ->
raise NotImplementedError()

@abstractmethod
def associate(
self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
) -> int:
def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
"""Possibly add datasets to the obscore table.
This method should be called when existing datasets are associated with
@@ -196,8 +196,6 @@ def associate(
the dataset is ignored.
collection : `CollectionRecord`
Collection record for a TAGGED collection.
context : `SqlQueryContext`
Context used to execute queries for additional dimension metadata.
Returns
-------
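The net effect of the interface change above is that callers no longer thread a `SqlQueryContext` through every call; the manager is instead constructed with `column_type_info` and builds its own context. A toy sketch of the before/after calling pattern (the classes and names below are illustrative stand-ins, not the real butler API):

```python
from collections.abc import Iterable


class QueryContext:
    """Stand-in for SqlQueryContext: something that can run metadata queries."""


class ObsCoreManagerBefore:
    # Old shape: every public method required the caller to supply a context.
    def add_datasets(self, refs: Iterable[str], context: QueryContext) -> int:
        return sum(1 for _ in refs)


class ObsCoreManagerAfter:
    # New shape: the context dependency is injected once, at construction time.
    def __init__(self, context: QueryContext) -> None:
        self._context = context

    def add_datasets(self, refs: Iterable[str]) -> int:
        return sum(1 for _ in refs)


manager = ObsCoreManagerAfter(QueryContext())
print(manager.add_datasets(["ref-1", "ref-2"]))  # prints 2
```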
23 changes: 12 additions & 11 deletions python/lsst/daf/butler/registry/managers.py
@@ -444,6 +444,17 @@ def initialize(
universe=universe,
registry_schema_version=schema_versions.get("datastores"),
)
kwargs["column_types"] = ColumnTypeInfo(
database.getTimespanRepresentation(),
universe,
dataset_id_spec=types.datasets.addDatasetForeignKey(
dummy_table,
primaryKey=False,
nullable=False,
),
run_key_spec=types.collections.addRunForeignKey(dummy_table, primaryKey=False, nullable=False),
ingest_date_dtype=datasets.ingest_date_dtype(),
)
if types.obscore is not None and "obscore" in types.manager_configs:
kwargs["obscore"] = types.obscore.initialize(
database,
@@ -453,20 +464,10 @@
datasets=types.datasets,
dimensions=kwargs["dimensions"],
registry_schema_version=schema_versions.get("obscore"),
column_type_info=kwargs["column_types"],
)
else:
kwargs["obscore"] = None
kwargs["column_types"] = ColumnTypeInfo(
database.getTimespanRepresentation(),
universe,
dataset_id_spec=types.datasets.addDatasetForeignKey(
dummy_table,
primaryKey=False,
nullable=False,
),
run_key_spec=types.collections.addRunForeignKey(dummy_table, primaryKey=False, nullable=False),
ingest_date_dtype=datasets.ingest_date_dtype(),
)
kwargs["caching_context"] = caching_context
return cls(**kwargs)

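The change in `managers.py` is purely an ordering one: the `ColumnTypeInfo` block moves above the obscore branch so that it already exists when it is handed to `obscore.initialize` as `column_type_info`. A minimal illustration of that dependency ordering, using dictionary stand-ins rather than the real manager types:

```python
# Illustrative only: build the shared column-type object first, then any
# component that needs it.  In the old ordering the obscore branch ran before
# kwargs["column_types"] existed, so it could not have been passed along.
kwargs: dict[str, object] = {}

kwargs["column_types"] = {"timespan": "TIMESTAMP", "ingest_date": "TIMESTAMP"}

obscore_configured = True
if obscore_configured:
    # The obscore component can now be handed the already-built info.
    kwargs["obscore"] = {"column_type_info": kwargs["column_types"]}
else:
    kwargs["obscore"] = None

print(kwargs["obscore"])
```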
36 changes: 24 additions & 12 deletions python/lsst/daf/butler/registry/obscore/_manager.py
@@ -43,7 +43,9 @@
from lsst.utils.introspection import find_outside_stacklevel
from lsst.utils.iteration import chunk_iterable

from ..._column_type_info import ColumnTypeInfo
from ..interfaces import ObsCoreTableManager, VersionTuple
from ..queries import SqlQueryContext
from ._config import ConfigCollectionType, ObsCoreManagerConfig
from ._records import ExposureRegionFactory, Record, RecordFactory
from ._schema import ObsCoreSchema
@@ -57,7 +59,6 @@
DimensionRecordStorageManager,
StaticTablesContext,
)
from ..queries import SqlQueryContext

_VERSION = VersionTuple(0, 0, 1)

@@ -71,14 +72,16 @@ class _ExposureRegionFactory(ExposureRegionFactory):
The dimension records storage manager.
"""

def __init__(self, dimensions: DimensionRecordStorageManager):
def __init__(self, dimensions: DimensionRecordStorageManager, context: SqlQueryContext):
self.dimensions = dimensions
self.universe = dimensions.universe
self.exposure_dimensions = self.universe["exposure"].minimal_group
self.exposure_detector_dimensions = self.universe.conform(["exposure", "detector"])
self._context = context

def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
def exposure_region(self, dataId: DataCoordinate) -> Region | None:
# Docstring is inherited from a base class.
context = self._context
# Make a relation that starts with visit_definition (mapping between
# exposure and visit).
relation = context.make_initial_relation()
@@ -134,6 +137,9 @@ class ObsCoreLiveTableManager(ObsCoreTableManager):
Spatial plugins.
registry_schema_version : `VersionTuple` or `None`, optional
Version of registry schema.
column_type_info : `ColumnTypeInfo`
Information about column types that can differ between data
repositories and registry instances.
"""

def __init__(
@@ -147,6 +153,7 @@ def __init__(
dimensions: DimensionRecordStorageManager,
spatial_plugins: Collection[SpatialObsCorePlugin],
registry_schema_version: VersionTuple | None = None,
column_type_info: ColumnTypeInfo,
):
super().__init__(registry_schema_version=registry_schema_version)
self.db = db
@@ -155,7 +162,11 @@
self.universe = universe
self.config = config
self.spatial_plugins = spatial_plugins
exposure_region_factory = _ExposureRegionFactory(dimensions)
self._column_type_info = column_type_info
exposure_region_factory = _ExposureRegionFactory(
dimensions,
SqlQueryContext(self.db, column_type_info),
)
self.record_factory = RecordFactory(
config, schema, universe, spatial_plugins, exposure_region_factory
)
@@ -189,6 +200,7 @@ def clone(self, *, db: Database, dimensions: DimensionRecordStorageManager) -> O
# 'initialize'.
spatial_plugins=self.spatial_plugins,
registry_schema_version=self._registry_schema_version,
column_type_info=self._column_type_info,
)

@classmethod
@@ -202,6 +214,7 @@ def initialize(
datasets: type[DatasetRecordStorageManager],
dimensions: DimensionRecordStorageManager,
registry_schema_version: VersionTuple | None = None,
column_type_info: ColumnTypeInfo,
) -> ObsCoreTableManager:
# Docstring inherited from base class.
config_data = Config(config)
@@ -227,6 +240,7 @@
dimensions=dimensions,
spatial_plugins=spatial_plugins,
registry_schema_version=registry_schema_version,
column_type_info=column_type_info,
)

def config_json(self) -> str:
@@ -244,7 +258,7 @@ def currentVersions(cls) -> list[VersionTuple]:
# Docstring inherited from base class.
return [_VERSION]

def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
def add_datasets(self, refs: Iterable[DatasetRef]) -> int:
# Docstring inherited from base class.

# Only makes sense for RUN collection types
@@ -279,19 +293,17 @@ def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) ->
# Take all refs, no collection check.
obscore_refs = refs

return self._populate(obscore_refs, context)
return self._populate(obscore_refs)

def associate(
self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
) -> int:
def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
# Docstring inherited from base class.

# Only works when collection type is TAGGED
if self.tagged_collection is None:
return 0

if collection.name == self.tagged_collection:
return self._populate(refs, context)
return self._populate(refs)
else:
return 0

@@ -315,11 +327,11 @@ def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord)
count += self.db.deleteWhere(self.table, where)
return count

def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
def _populate(self, refs: Iterable[DatasetRef]) -> int:
"""Populate obscore table with the data from given datasets."""
records: list[Record] = []
for ref in refs:
record = self.record_factory(ref, context)
record = self.record_factory(ref)
if record is not None:
records.append(record)

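Inside `ObsCoreLiveTableManager`, the injected `column_type_info` is used to build one `SqlQueryContext` up front, and the exposure-region factory keeps that context as instance state instead of receiving a fresh one on each `exposure_region` call. A reduced sketch of that ownership pattern (the real constructors take many more arguments; all names here are placeholders):

```python
class FakeQueryContext:
    """Placeholder for SqlQueryContext(db, column_type_info)."""


class ExposureRegionFactorySketch:
    # The factory now owns its query context for its whole lifetime.
    def __init__(self, context: FakeQueryContext) -> None:
        self._context = context

    def exposure_region(self, data_id: dict[str, int]) -> None:
        # In the real code the stored context is used to join exposure ->
        # visit_definition -> visit_detector_region and read back a region.
        _ = (self._context, data_id)
        return None


class LiveTableManagerSketch:
    def __init__(self, column_type_info: object) -> None:
        # Build the context once from the injected column-type information.
        self._column_type_info = column_type_info
        self.region_factory = ExposureRegionFactorySketch(FakeQueryContext())


manager = LiveTableManagerSketch(column_type_info={"dataset_id": "UUID"})
print(manager.region_factory.exposure_region({"exposure": 1, "detector": 2}))  # None
```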
11 changes: 3 additions & 8 deletions python/lsst/daf/butler/registry/obscore/_records.py
@@ -49,9 +49,6 @@
from ._schema import ObsCoreSchema
from ._spatial import SpatialObsCorePlugin

if TYPE_CHECKING:
from ..queries import SqlQueryContext

_LOG = logging.getLogger(__name__)

# Map extra column type to a conversion method that takes string.
@@ -67,15 +64,13 @@ class ExposureRegionFactory:
"""Abstract interface for a class that returns a Region for an exposure."""

@abstractmethod
def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
def exposure_region(self, dataId: DataCoordinate) -> Region | None:
"""Return a region for a given DataId that corresponds to an exposure.
Parameters
----------
dataId : `DataCoordinate`
Data ID for an exposure dataset.
context : `SqlQueryContext`
Context used to execute queries for additional dimension metadata.
Returns
-------
@@ -125,7 +120,7 @@ def __init__(
self.visit = universe["visit"]
self.physical_filter = cast(Dimension, universe["physical_filter"])

def __call__(self, ref: DatasetRef, context: SqlQueryContext) -> Record | None:
def __call__(self, ref: DatasetRef) -> Record | None:
"""Make an ObsCore record from a dataset.
Parameters
@@ -194,7 +189,7 @@ def __call__(self, ref: DatasetRef, context: SqlQueryContext) -> Record | None:
if (dimension_record := dataId.records[self.exposure.name]) is not None:
self._exposure_records(dimension_record, record)
if self.exposure_region_factory is not None:
region = self.exposure_region_factory.exposure_region(dataId, context)
region = self.exposure_region_factory.exposure_region(dataId)
elif self.visit.name in dataId and (dimension_record := dataId.records[self.visit.name]) is not None:
self._visit_records(dimension_record, record)

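With the context moved into the factory, the abstract `ExposureRegionFactory.exposure_region` hook and `RecordFactory.__call__` shrink to taking only the data they describe. A hedged sketch of the simplified contract (the `Region` class below is a stand-in for `lsst.sphgeom.Region`):

```python
from abc import ABC, abstractmethod


class Region:
    """Stand-in for lsst.sphgeom.Region."""


class ExposureRegionFactorySketch(ABC):
    """Return a Region for an exposure, given only its data ID."""

    @abstractmethod
    def exposure_region(self, data_id: dict[str, int]) -> Region | None:
        raise NotImplementedError()


class NullRegionFactory(ExposureRegionFactorySketch):
    # Trivial implementation: never finds a region, which callers must tolerate.
    def exposure_region(self, data_id: dict[str, int]) -> Region | None:
        return None


print(NullRegionFactory().exposure_region({"exposure": 42}))  # None
```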
9 changes: 3 additions & 6 deletions python/lsst/daf/butler/registry/sql_registry.py
@@ -1078,8 +1078,7 @@ def insertDatasets(
try:
refs = list(storage.insert(runRecord, expandedDataIds, idGenerationMode))
if self._managers.obscore:
context = queries.SqlQueryContext(self._db, self._managers.column_types)
self._managers.obscore.add_datasets(refs, context)
self._managers.obscore.add_datasets(refs)
except sqlalchemy.exc.IntegrityError as err:
raise ConflictingDefinitionError(
"A database constraint failure was triggered by inserting "
@@ -1193,8 +1192,7 @@ def _importDatasets(
try:
refs = list(storage.import_(runRecord, expandedDatasets))
if self._managers.obscore:
context = queries.SqlQueryContext(self._db, self._managers.column_types)
self._managers.obscore.add_datasets(refs, context)
self._managers.obscore.add_datasets(refs)
except sqlalchemy.exc.IntegrityError as err:
raise ConflictingDefinitionError(
"A database constraint failure was triggered by inserting "
@@ -1307,8 +1305,7 @@ def associate(self, collection: str, refs: Iterable[DatasetRef]) -> None:
if self._managers.obscore:
# If a TAGGED collection is being monitored by ObsCore
# manager then we may need to save the dataset.
context = queries.SqlQueryContext(self._db, self._managers.column_types)
self._managers.obscore.associate(refsForType, collectionRecord, context)
self._managers.obscore.associate(refsForType, collectionRecord)
except sqlalchemy.exc.IntegrityError as err:
raise ConflictingDefinitionError(
f"Constraint violation while associating dataset of type {datasetType.name} with "
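The three `SqlRegistry` call sites shown above no longer construct a throwaway `SqlQueryContext` before delegating to the obscore manager; each guard reduces to an `if` plus the call. A minimal sketch of the simplified call pattern, with placeholder classes standing in for the registry's manager bundle:

```python
class FakeObsCoreManager:
    """Placeholder for the obscore manager; just counts what it is given."""

    def add_datasets(self, refs: list[str]) -> int:
        return len(refs)


class FakeManagers:
    """Placeholder for the registry's manager bundle."""

    def __init__(self) -> None:
        self.obscore: FakeObsCoreManager | None = FakeObsCoreManager()


managers = FakeManagers()
refs = ["dataset-1", "dataset-2"]

if managers.obscore:
    # No per-call SqlQueryContext construction here anymore.
    print(managers.obscore.add_datasets(refs))  # prints 2
```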
