DM-46363: Inject SqlQueryContext into ObsCoreManager at top level #1081

Merged · 1 commit · Sep 17, 2024
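
The change in one sentence: SqlQueryContext is no longer threaded through every ObsCoreTableManager call; the manager receives a ColumnTypeInfo when it is initialized and builds its own context. A minimal before/after sketch of the caller side, lifted from the sql_registry.py hunks further down (the queries and self._managers names are as they appear in that file, nothing here is new API):

    # Before this PR: each call site built a throwaway SqlQueryContext.
    context = queries.SqlQueryContext(self._db, self._managers.column_types)
    self._managers.obscore.add_datasets(refs, context)

    # After this PR: the manager owns its context, built from the injected
    # ColumnTypeInfo, so callers pass only the refs.
    self._managers.obscore.add_datasets(refs)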
16 changes: 7 additions & 9 deletions python/lsst/daf/butler/registry/interfaces/_obscore.py
@@ -44,9 +44,9 @@
if TYPE_CHECKING:
from lsst.sphgeom import Region

from ..._column_type_info import ColumnTypeInfo
from ..._dataset_ref import DatasetRef
from ...dimensions import DimensionUniverse
from ..queries import SqlQueryContext
from ._collections import CollectionRecord
from ._database import Database, StaticTablesContext
from ._datasets import DatasetRecordStorageManager
@@ -103,6 +103,7 @@ def initialize(
datasets: type[DatasetRecordStorageManager],
dimensions: DimensionRecordStorageManager,
registry_schema_version: VersionTuple | None = None,
column_type_info: ColumnTypeInfo,
) -> ObsCoreTableManager:
"""Construct an instance of the manager.

@@ -124,6 +125,9 @@
Manager for Registry dimensions.
registry_schema_version : `VersionTuple` or `None`
Schema version of this extension as defined in registry.
column_type_info : `ColumnTypeInfo`
Information about column types that can differ between data
repositories and registry instances.

Returns
-------
@@ -144,7 +148,7 @@ def config_json(self) -> str:
raise NotImplementedError()

@abstractmethod
def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
def add_datasets(self, refs: Iterable[DatasetRef]) -> int:
"""Possibly add datasets to the obscore table.

This method should be called when new datasets are added to a RUN
@@ -156,8 +160,6 @@ def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) ->
Dataset refs to add. Dataset refs have to be completely expanded.
If a record with the same dataset ID is already in obscore table,
the dataset is ignored.
context : `SqlQueryContext`
Context used to execute queries for additional dimension metadata.

Returns
-------
@@ -180,9 +182,7 @@ def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) ->
raise NotImplementedError()

@abstractmethod
def associate(
self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
) -> int:
def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
"""Possibly add datasets to the obscore table.

This method should be called when existing datasets are associated with
@@ -196,8 +196,6 @@ def associate(
the dataset is ignored.
collection : `CollectionRecord`
Collection record for a TAGGED collection.
context : `SqlQueryContext`
Context used to execute queries for additional dimension metadata.

Returns
-------
23 changes: 12 additions & 11 deletions python/lsst/daf/butler/registry/managers.py
@@ -444,6 +444,17 @@ def initialize(
universe=universe,
registry_schema_version=schema_versions.get("datastores"),
)
kwargs["column_types"] = ColumnTypeInfo(
database.getTimespanRepresentation(),
universe,
dataset_id_spec=types.datasets.addDatasetForeignKey(
dummy_table,
primaryKey=False,
nullable=False,
),
run_key_spec=types.collections.addRunForeignKey(dummy_table, primaryKey=False, nullable=False),
ingest_date_dtype=datasets.ingest_date_dtype(),
)
if types.obscore is not None and "obscore" in types.manager_configs:
kwargs["obscore"] = types.obscore.initialize(
database,
@@ -453,20 +464,10 @@
datasets=types.datasets,
dimensions=kwargs["dimensions"],
registry_schema_version=schema_versions.get("obscore"),
column_type_info=kwargs["column_types"],
)
else:
kwargs["obscore"] = None
kwargs["column_types"] = ColumnTypeInfo(
database.getTimespanRepresentation(),
universe,
dataset_id_spec=types.datasets.addDatasetForeignKey(
dummy_table,
primaryKey=False,
nullable=False,
),
run_key_spec=types.collections.addRunForeignKey(dummy_table, primaryKey=False, nullable=False),
ingest_date_dtype=datasets.ingest_date_dtype(),
)
kwargs["caching_context"] = caching_context
return cls(**kwargs)

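The managers.py change above is essentially a reordering: the ColumnTypeInfo construction moves ahead of the obscore branch because the obscore manager now consumes it as column_type_info. A condensed sketch of the resulting flow, with ellipses standing in for the arguments this PR leaves unchanged (illustrative only, not runnable as written):

    kwargs["column_types"] = ColumnTypeInfo(...)  # now built before the obscore branch
    if types.obscore is not None and "obscore" in types.manager_configs:
        kwargs["obscore"] = types.obscore.initialize(
            ...,  # database, universe, config, datasets, dimensions, schema version as before
            column_type_info=kwargs["column_types"],  # new keyword; the reason for the move
        )
    else:
        kwargs["obscore"] = None
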
36 changes: 24 additions & 12 deletions python/lsst/daf/butler/registry/obscore/_manager.py
@@ -43,7 +43,9 @@
from lsst.utils.introspection import find_outside_stacklevel
from lsst.utils.iteration import chunk_iterable

from ..._column_type_info import ColumnTypeInfo
from ..interfaces import ObsCoreTableManager, VersionTuple
from ..queries import SqlQueryContext
from ._config import ConfigCollectionType, ObsCoreManagerConfig
from ._records import ExposureRegionFactory, Record, RecordFactory
from ._schema import ObsCoreSchema
@@ -57,7 +59,6 @@
DimensionRecordStorageManager,
StaticTablesContext,
)
from ..queries import SqlQueryContext

_VERSION = VersionTuple(0, 0, 1)

@@ -71,14 +72,16 @@ class _ExposureRegionFactory(ExposureRegionFactory):
The dimension records storage manager.
"""

def __init__(self, dimensions: DimensionRecordStorageManager):
def __init__(self, dimensions: DimensionRecordStorageManager, context: SqlQueryContext):
self.dimensions = dimensions
self.universe = dimensions.universe
self.exposure_dimensions = self.universe["exposure"].minimal_group
self.exposure_detector_dimensions = self.universe.conform(["exposure", "detector"])
self._context = context

def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
def exposure_region(self, dataId: DataCoordinate) -> Region | None:
# Docstring is inherited from a base class.
context = self._context
# Make a relation that starts with visit_definition (mapping between
# exposure and visit).
relation = context.make_initial_relation()
@@ -134,6 +137,9 @@ class ObsCoreLiveTableManager(ObsCoreTableManager):
Spatial plugins.
registry_schema_version : `VersionTuple` or `None`, optional
Version of registry schema.
column_type_info : `ColumnTypeInfo`
Information about column types that can differ between data
repositories and registry instances.
"""

def __init__(
@@ -147,6 +153,7 @@ def __init__(
dimensions: DimensionRecordStorageManager,
spatial_plugins: Collection[SpatialObsCorePlugin],
registry_schema_version: VersionTuple | None = None,
column_type_info: ColumnTypeInfo,
):
super().__init__(registry_schema_version=registry_schema_version)
self.db = db
@@ -155,7 +162,11 @@
self.universe = universe
self.config = config
self.spatial_plugins = spatial_plugins
exposure_region_factory = _ExposureRegionFactory(dimensions)
self._column_type_info = column_type_info
exposure_region_factory = _ExposureRegionFactory(
dimensions,
SqlQueryContext(self.db, column_type_info),
)
self.record_factory = RecordFactory(
config, schema, universe, spatial_plugins, exposure_region_factory
)
@@ -189,6 +200,7 @@ def clone(self, *, db: Database, dimensions: DimensionRecordStorageManager) -> O
# 'initialize'.
spatial_plugins=self.spatial_plugins,
registry_schema_version=self._registry_schema_version,
column_type_info=self._column_type_info,
)

@classmethod
@@ -202,6 +214,7 @@ def initialize(
datasets: type[DatasetRecordStorageManager],
dimensions: DimensionRecordStorageManager,
registry_schema_version: VersionTuple | None = None,
column_type_info: ColumnTypeInfo,
) -> ObsCoreTableManager:
# Docstring inherited from base class.
config_data = Config(config)
@@ -227,6 +240,7 @@
dimensions=dimensions,
spatial_plugins=spatial_plugins,
registry_schema_version=registry_schema_version,
column_type_info=column_type_info,
)

def config_json(self) -> str:
@@ -244,7 +258,7 @@ def currentVersions(cls) -> list[VersionTuple]:
# Docstring inherited from base class.
return [_VERSION]

def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
def add_datasets(self, refs: Iterable[DatasetRef]) -> int:
# Docstring inherited from base class.

# Only makes sense for RUN collection types
@@ -279,19 +293,17 @@ def add_datasets(self, refs: Iterable[DatasetRef], context: SqlQueryContext) ->
# Take all refs, no collection check.
obscore_refs = refs

return self._populate(obscore_refs, context)
return self._populate(obscore_refs)

def associate(
self, refs: Iterable[DatasetRef], collection: CollectionRecord, context: SqlQueryContext
) -> int:
def associate(self, refs: Iterable[DatasetRef], collection: CollectionRecord) -> int:
# Docstring inherited from base class.

# Only works when collection type is TAGGED
if self.tagged_collection is None:
return 0

if collection.name == self.tagged_collection:
return self._populate(refs, context)
return self._populate(refs)
else:
return 0

@@ -315,11 +327,11 @@ def disassociate(self, refs: Iterable[DatasetRef], collection: CollectionRecord)
count += self.db.deleteWhere(self.table, where)
return count

def _populate(self, refs: Iterable[DatasetRef], context: SqlQueryContext) -> int:
def _populate(self, refs: Iterable[DatasetRef]) -> int:
"""Populate obscore table with the data from given datasets."""
records: list[Record] = []
for ref in refs:
record = self.record_factory(ref, context)
record = self.record_factory(ref)
if record is not None:
records.append(record)

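The structural pattern in _manager.py above: ObsCoreLiveTableManager now builds a single SqlQueryContext from the injected ColumnTypeInfo in __init__ and hands it to _ExposureRegionFactory, instead of accepting a context in add_datasets, associate and _populate. A self-contained toy of that construct-once pattern; every name below (QueryContext, RegionFactory, Manager) is invented for illustration and is not butler API:

    from dataclasses import dataclass


    @dataclass
    class QueryContext:
        # Stand-in for SqlQueryContext: state that used to be passed per call.
        db: str
        column_types: str


    class RegionFactory:
        # Stand-in for _ExposureRegionFactory: keeps the context it needs.
        def __init__(self, context: QueryContext) -> None:
            self._context = context

        def region(self, data_id: int) -> str:
            # Uses the stored context rather than one supplied by the caller.
            return f"region {data_id} via {self._context.db}"


    class Manager:
        # Stand-in for ObsCoreLiveTableManager.
        def __init__(self, db: str, column_types: str) -> None:
            # The context is built once, where the manager itself is built.
            self._factory = RegionFactory(QueryContext(db, column_types))

        def add_datasets(self, refs: list[int]) -> int:
            # Callers no longer provide a context argument.
            return sum(1 for ref in refs if self._factory.region(ref))


    if __name__ == "__main__":
        manager = Manager(db="sqlite:///demo.db", column_types="default")
        print(manager.add_datasets([1, 2, 3]))  # prints 3
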
11 changes: 3 additions & 8 deletions python/lsst/daf/butler/registry/obscore/_records.py
@@ -49,9 +49,6 @@
from ._schema import ObsCoreSchema
from ._spatial import SpatialObsCorePlugin

if TYPE_CHECKING:
from ..queries import SqlQueryContext

_LOG = logging.getLogger(__name__)

# Map extra column type to a conversion method that takes string.
@@ -67,15 +64,13 @@ class ExposureRegionFactory:
"""Abstract interface for a class that returns a Region for an exposure."""

@abstractmethod
def exposure_region(self, dataId: DataCoordinate, context: SqlQueryContext) -> Region | None:
def exposure_region(self, dataId: DataCoordinate) -> Region | None:
"""Return a region for a given DataId that corresponds to an exposure.

Parameters
----------
dataId : `DataCoordinate`
Data ID for an exposure dataset.
context : `SqlQueryContext`
Context used to execute queries for additional dimension metadata.

Returns
-------
@@ -125,7 +120,7 @@ def __init__(
self.visit = universe["visit"]
self.physical_filter = cast(Dimension, universe["physical_filter"])

def __call__(self, ref: DatasetRef, context: SqlQueryContext) -> Record | None:
def __call__(self, ref: DatasetRef) -> Record | None:
"""Make an ObsCore record from a dataset.

Parameters
@@ -194,7 +189,7 @@ def __call__(self, ref: DatasetRef, context: SqlQueryContext) -> Record | None:
if (dimension_record := dataId.records[self.exposure.name]) is not None:
self._exposure_records(dimension_record, record)
if self.exposure_region_factory is not None:
region = self.exposure_region_factory.exposure_region(dataId, context)
region = self.exposure_region_factory.exposure_region(dataId)
elif self.visit.name in dataId and (dimension_record := dataId.records[self.visit.name]) is not None:
self._visit_records(dimension_record, record)

9 changes: 3 additions & 6 deletions python/lsst/daf/butler/registry/sql_registry.py
@@ -1078,8 +1078,7 @@ def insertDatasets(
try:
refs = list(storage.insert(runRecord, expandedDataIds, idGenerationMode))
if self._managers.obscore:
context = queries.SqlQueryContext(self._db, self._managers.column_types)
self._managers.obscore.add_datasets(refs, context)
self._managers.obscore.add_datasets(refs)
except sqlalchemy.exc.IntegrityError as err:
raise ConflictingDefinitionError(
"A database constraint failure was triggered by inserting "
@@ -1193,8 +1192,7 @@ def _importDatasets(
try:
refs = list(storage.import_(runRecord, expandedDatasets))
if self._managers.obscore:
context = queries.SqlQueryContext(self._db, self._managers.column_types)
self._managers.obscore.add_datasets(refs, context)
self._managers.obscore.add_datasets(refs)
except sqlalchemy.exc.IntegrityError as err:
raise ConflictingDefinitionError(
"A database constraint failure was triggered by inserting "
@@ -1307,8 +1305,7 @@ def associate(self, collection: str, refs: Iterable[DatasetRef]) -> None:
if self._managers.obscore:
# If a TAGGED collection is being monitored by ObsCore
# manager then we may need to save the dataset.
context = queries.SqlQueryContext(self._db, self._managers.column_types)
self._managers.obscore.associate(refsForType, collectionRecord, context)
self._managers.obscore.associate(refsForType, collectionRecord)
except sqlalchemy.exc.IntegrityError as err:
raise ConflictingDefinitionError(
f"Constraint violation while associating dataset of type {datasetType.name} with "